xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
17 
18 #include "absl/strings/str_join.h"
19 #include "absl/strings/str_split.h"
20 #include "tensorflow_lite_support/cc/port/status_macros.h"
21 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
22 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
23 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
24 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
25 
26 namespace tflite {
27 namespace task {
28 namespace text {
29 namespace qa {
30 
// Tensor names as declared in the model metadata. When metadata is present,
// input/output tensors are looked up by these names; otherwise a fixed
// positional order is assumed (see Preprocess()/Postprocess()).
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
constexpr char kEndLogitsTensorName[] = "end_logits";
constexpr char kStartLogitsTensorName[] = "start_logits";

using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::BertTokenizer;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::SentencePieceTokenizer;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindTensorByName;
using ::tflite::task::core::PopulateTensor;
using ::tflite::task::core::PopulateVector;
using ::tflite::task::core::ReverseSortIndices;

namespace {
// Index of the process unit that carries the tokenizer options within the
// input metadata.
constexpr int kTokenizerProcessUnitIndex = 0;
}  // namespace
52 
53 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateFromFile(const std::string & path_to_model_with_metadata)54 BertQuestionAnswerer::CreateFromFile(
55     const std::string& path_to_model_with_metadata) {
56   std::unique_ptr<BertQuestionAnswerer> api_to_init;
57   ASSIGN_OR_RETURN(
58       api_to_init,
59       core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
60           path_to_model_with_metadata,
61           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
62           kNumLiteThreads));
63   RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
64   return api_to_init;
65 }
66 
67 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateFromBuffer(const char * model_with_metadata_buffer_data,size_t model_with_metadata_buffer_size)68 BertQuestionAnswerer::CreateFromBuffer(
69     const char* model_with_metadata_buffer_data,
70     size_t model_with_metadata_buffer_size) {
71   std::unique_ptr<BertQuestionAnswerer> api_to_init;
72   ASSIGN_OR_RETURN(
73       api_to_init,
74       core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
75           model_with_metadata_buffer_data, model_with_metadata_buffer_size,
76           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
77           kNumLiteThreads));
78   RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
79   return api_to_init;
80 }
81 
CreateFromFd(int fd)82 StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
83     int fd) {
84   std::unique_ptr<BertQuestionAnswerer> api_to_init;
85   ASSIGN_OR_RETURN(
86       api_to_init,
87       core::TaskAPIFactory::CreateFromFileDescriptor<BertQuestionAnswerer>(
88           fd, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
89           kNumLiteThreads));
90   RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
91   return api_to_init;
92 }
93 
94 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateBertQuestionAnswererFromFile(const std::string & path_to_model,const std::string & path_to_vocab)95 BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
96     const std::string& path_to_model, const std::string& path_to_vocab) {
97   std::unique_ptr<BertQuestionAnswerer> api_to_init;
98   ASSIGN_OR_RETURN(
99       api_to_init,
100       core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
101           path_to_model,
102           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
103           kNumLiteThreads));
104   api_to_init->InitializeBertTokenizer(path_to_vocab);
105   return api_to_init;
106 }
107 
108 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateBertQuestionAnswererFromBuffer(const char * model_buffer_data,size_t model_buffer_size,const char * vocab_buffer_data,size_t vocab_buffer_size)109 BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
110     const char* model_buffer_data, size_t model_buffer_size,
111     const char* vocab_buffer_data, size_t vocab_buffer_size) {
112   std::unique_ptr<BertQuestionAnswerer> api_to_init;
113   ASSIGN_OR_RETURN(
114       api_to_init,
115       core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
116           model_buffer_data, model_buffer_size,
117           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
118           kNumLiteThreads));
119   api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data,
120                                                  vocab_buffer_size);
121   return api_to_init;
122 }
123 
124 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateAlbertQuestionAnswererFromFile(const std::string & path_to_model,const std::string & path_to_spmodel)125 BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
126     const std::string& path_to_model, const std::string& path_to_spmodel) {
127   std::unique_ptr<BertQuestionAnswerer> api_to_init;
128   ASSIGN_OR_RETURN(
129       api_to_init,
130       core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
131           path_to_model,
132           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
133           kNumLiteThreads));
134   api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel);
135   return api_to_init;
136 }
137 
138 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateAlbertQuestionAnswererFromBuffer(const char * model_buffer_data,size_t model_buffer_size,const char * spmodel_buffer_data,size_t spmodel_buffer_size)139 BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
140     const char* model_buffer_data, size_t model_buffer_size,
141     const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
142   std::unique_ptr<BertQuestionAnswerer> api_to_init;
143   ASSIGN_OR_RETURN(
144       api_to_init,
145       core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
146           model_buffer_data, model_buffer_size,
147           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
148           kNumLiteThreads));
149   api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data,
150                                                           spmodel_buffer_size);
151   return api_to_init;
152 }
153 
// Synchronous convenience entry point: runs the full pipeline
// (Preprocess -> interpreter -> Postprocess) on one (context, question) pair.
std::vector<QaAnswer> BertQuestionAnswerer::Answer(
    const std::string& context, const std::string& question) {
  // The BertQuestionAnswerer implementations of Preprocess() and
  // Postprocess() always return OkStatus, so value() is called directly.
  // NOTE(review): if Infer() could also surface interpreter invocation
  // failures, value() would abort — confirm against the TaskAPI base class.
  return Infer(context, question).value();
}
160 
// Converts (context, query) into the three INT32 input tensors expected by
// the BERT QA model — token ids, attention mask and segment ids — each padded
// with zeros to kMaxSeqLen. As a side effect, rebuilds orig_tokens_ and
// token_to_orig_map_, which Postprocess()/ConvertIndexToString() later use to
// map subword positions back to whole words of the original context.
// Always returns OkStatus.
absl::Status BertQuestionAnswerer::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
    const std::string& query) {
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();
  // With metadata, locate each input tensor by its declared name; without
  // metadata, assume the fixed order: ids, mask, segment_ids.
  TfLiteTensor* ids_tensor =
      input_tensor_metadatas
          ? FindTensorByName(input_tensors, input_tensor_metadatas,
                             kIdsTensorName)
          : input_tensors[0];
  TfLiteTensor* mask_tensor =
      input_tensor_metadatas
          ? FindTensorByName(input_tensors, input_tensor_metadatas,
                             kMaskTensorName)
          : input_tensors[1];
  TfLiteTensor* segment_ids_tensor =
      input_tensor_metadatas
          ? FindTensorByName(input_tensors, input_tensor_metadatas,
                             kSegmentIdsTensorName)
          : input_tensors[2];

  token_to_orig_map_.clear();

  // The orig_tokens is used for recovering the answer string from the index,
  // while the processed_tokens is lower-cased and used to generate input of
  // the model.
  orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty());
  std::vector<std::string> processed_tokens(orig_tokens_);

  std::string processed_query = query;
  if (kUseLowerCase) {
    for (auto& token : processed_tokens) {
      absl::AsciiStrToLower(&token);
    }
    absl::AsciiStrToLower(&processed_query);
  }

  TokenizerResult query_tokenize_results;
  query_tokenize_results = tokenizer_->Tokenize(processed_query);

  // Truncate the query to at most kMaxQueryLen subwords.
  std::vector<std::string> query_tokens = query_tokenize_results.subwords;
  if (query_tokens.size() > kMaxQueryLen) {
    query_tokens.resize(kMaxQueryLen);
  }

  // Example:
  // context:             tokenize     me  please
  // all_doc_tokens:      token ##ize  me  plea ##se
  // token_to_orig_index: [0,   0,     1,  2,   2]

  // Subword-tokenize each context word, remembering for every subword which
  // original word it came from.
  std::vector<std::string> all_doc_tokens;
  std::vector<int> token_to_orig_index;
  for (size_t i = 0; i < processed_tokens.size(); i++) {
    const std::string& token = processed_tokens[i];
    std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords;
    for (const std::string& sub_token : sub_tokens) {
      token_to_orig_index.emplace_back(i);
      all_doc_tokens.emplace_back(sub_token);
    }
  }

  // -3 accounts for [CLS], [SEP] and [SEP].
  int max_context_len = kMaxSeqLen - query_tokens.size() - 3;
  if (all_doc_tokens.size() > max_context_len) {
    all_doc_tokens.resize(max_context_len);
  }

  // Sequence layout: [CLS] query... [SEP] context... [SEP], with segment id 0
  // for the query half and 1 for the context half.
  std::vector<std::string> tokens;
  tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size());
  std::vector<int> segment_ids;
  segment_ids.reserve(kMaxSeqLen);

  // Start of generating the features.
  tokens.emplace_back("[CLS]");
  segment_ids.emplace_back(0);

  // For query input.
  for (const auto& query_token : query_tokens) {
    tokens.emplace_back(query_token);
    segment_ids.emplace_back(0);
  }

  // For Separation.
  tokens.emplace_back("[SEP]");
  segment_ids.emplace_back(0);

  // For Text Input.
  for (int i = 0; i < all_doc_tokens.size(); i++) {
    auto& doc_token = all_doc_tokens[i];
    tokens.emplace_back(doc_token);
    segment_ids.emplace_back(1);
    // NOTE: tokens.size() is read AFTER the emplace above, so the stored key
    // is the token's sequence index + 1. Postprocess() compensates by adding
    // kOutputOffset to the raw logit index — presumably kOutputOffset == 1;
    // confirm against the header.
    token_to_orig_map_[tokens.size()] = token_to_orig_index[i];
  }

  // For ending mark.
  tokens.emplace_back("[SEP]");
  segment_ids.emplace_back(1);

  std::vector<int> input_ids(tokens.size());
  input_ids.reserve(kMaxSeqLen);
  // Convert tokens back into ids
  for (int i = 0; i < tokens.size(); i++) {
    auto& token = tokens[i];
    tokenizer_->LookupId(token, &input_ids[i]);
  }

  // Attention mask: 1 for every real token, 0 for padding appended below.
  std::vector<int> input_mask;
  input_mask.reserve(kMaxSeqLen);
  input_mask.insert(input_mask.end(), tokens.size(), 1);

  // Zero-pad all three feature vectors out to the fixed sequence length.
  int zeros_to_pad = kMaxSeqLen - input_ids.size();
  input_ids.insert(input_ids.end(), zeros_to_pad, 0);
  input_mask.insert(input_mask.end(), zeros_to_pad, 0);
  segment_ids.insert(segment_ids.end(), zeros_to_pad, 0);

  // input_ids INT32[1, 384]
  PopulateTensor(input_ids, ids_tensor);
  // input_mask INT32[1, 384]
  PopulateTensor(input_mask, mask_tensor);
  // segment_ids INT32[1, 384]
  PopulateTensor(segment_ids, segment_ids_tensor);

  return absl::OkStatus();
}
285 
// Turns the model's start/end logit tensors into up to kPredictAnsNum ranked
// QaAnswer spans over the original context. The lowercased context/query
// parameters are unused here because Preprocess() already cached the mapping
// state (orig_tokens_, token_to_orig_map_) needed to recover answer text.
// Always returns a value (never an error status).
StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*lowercased_context*/,
    const std::string& /*lowercased_query*/) {
  auto* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  // With metadata, locate the logit tensors by name; otherwise assume the
  // fixed order: end logits first, then start logits.
  const TfLiteTensor* end_logits_tensor =
      output_tensor_metadatas
          ? FindTensorByName(output_tensors, output_tensor_metadatas,
                             kEndLogitsTensorName)
          : output_tensors[0];
  const TfLiteTensor* start_logits_tensor =
      output_tensor_metadatas
          ? FindTensorByName(output_tensors, output_tensor_metadatas,
                             kStartLogitsTensorName)
          : output_tensors[1];

  std::vector<float> end_logits;
  std::vector<float> start_logits;

  // end_logits FLOAT[1, 384]
  PopulateVector(end_logits_tensor, &end_logits);
  // start_logits FLOAT[1, 384]
  PopulateVector(start_logits_tensor, &start_logits);

  // Sequence positions ordered from highest to lowest logit.
  auto start_indices = ReverseSortIndices(start_logits);
  auto end_indices = ReverseSortIndices(end_logits);

  // Cross the top kPredictAnsNum start positions with the top kPredictAnsNum
  // end positions, keeping only spans that (a) map back into the original
  // context via token_to_orig_map_, (b) are well-ordered, and (c) are at most
  // kMaxAnsLen tokens long. A span's score is start logit + end logit.
  std::vector<QaAnswer::Pos> orig_results;
  for (int start_index = 0; start_index < kPredictAnsNum; start_index++) {
    for (int end_index = 0; end_index < kPredictAnsNum; end_index++) {
      int start = start_indices[start_index];
      int end = end_indices[end_index];

      if (!token_to_orig_map_.contains(start + kOutputOffset) ||
          !token_to_orig_map_.contains(end + kOutputOffset) || end < start ||
          (end - start + 1) > kMaxAnsLen) {
        continue;
      }
      orig_results.emplace_back(
          QaAnswer::Pos(start, end, start_logits[start] + end_logits[end]));
    }
  }

  // Rank candidates via QaAnswer::Pos's comparison operator — presumably it
  // orders by descending score; confirm against its declaration.
  std::sort(orig_results.begin(), orig_results.end());

  // Emit at most kPredictAnsNum answers. A span starting at position 0 (the
  // [CLS] token) is treated as "no answer" and yields an empty string.
  std::vector<QaAnswer> answers;
  for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) {
    auto orig_pos = orig_results[i];
    answers.emplace_back(
        orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end)
                           : "",
        orig_pos);
  }

  return answers;
}
344 
ConvertIndexToString(int start,int end)345 std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) {
346   int start_index = token_to_orig_map_[start + kOutputOffset];
347   int end_index = token_to_orig_map_[end + kOutputOffset];
348 
349   return absl::StrJoin(orig_tokens_.begin() + start_index,
350                        orig_tokens_.begin() + end_index + 1, " ");
351 }
352 
InitializeFromMetadata()353 absl::Status BertQuestionAnswerer::InitializeFromMetadata() {
354   const ProcessUnit* tokenizer_process_unit =
355       GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
356   if (tokenizer_process_unit == nullptr) {
357     return CreateStatusWithPayload(
358         absl::StatusCode::kInvalidArgument,
359         "No input process unit found from metadata.",
360         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
361   }
362   ASSIGN_OR_RETURN(tokenizer_,
363                    CreateTokenizerFromProcessUnit(tokenizer_process_unit,
364                                                   GetMetadataExtractor()));
365   return absl::OkStatus();
366 }
367 
// Replaces the current tokenizer with a WordPiece (BERT) tokenizer whose
// vocabulary is loaded from the file at `path_to_vocab`.
void BertQuestionAnswerer::InitializeBertTokenizer(
    const std::string& path_to_vocab) {
  tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab);
}
372 
// Replaces the current tokenizer with a WordPiece (BERT) tokenizer whose
// vocabulary is read from the given in-memory buffer.
void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
    const char* vocab_buffer_data, size_t vocab_buffer_size) {
  tokenizer_ =
      absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
}
378 
// Replaces the current tokenizer with a SentencePiece tokenizer (ALBERT
// models) loaded from the model file at `path_to_spmodel`.
void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
    const std::string& path_to_spmodel) {
  tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel);
}
383 
// Replaces the current tokenizer with a SentencePiece tokenizer (ALBERT
// models) read from the given in-memory buffer.
void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
    const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
  tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
                                                         spmodel_buffer_size);
}
389 
390 }  // namespace qa
391 }  // namespace text
392 }  // namespace task
393 }  // namespace tflite
394