xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/status/status.h"
21 #include "tensorflow_lite_support/cc/port/statusor.h"
22 #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
23 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
24 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
25 #include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h"
26 #include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
27 #include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
28 
29 namespace tflite {
30 namespace task {
31 namespace text {
32 namespace qa {
33 
34 // BertQA task API, performs tokenization for models (BERT, Albert, etc.) in
// preprocess and returns the most probable answers.
36 //
37 // In particular, the branch of BERT models use WordPiece tokenizer, and the
38 // branch of Albert models use SentencePiece tokenizer, respectively.
39 //
40 // Factory methods:
41 //   CreateFromFile(path_to_model_with_metadata)
42 //   CreateFromBuffer(model_with_metadata_buffer_data,
43 //                                model_with_metadata_buffer_size)
44 //   CreateFromFd(file_descriptor_to_model_with_metadata)
45 //     Generic API to create the QuestionAnswerer for bert models with metadata
46 //     populated. The API expects a Bert based TFLite model with metadata
47 //     containing the following information:
48 //       - input_process_units for Wordpiece/Sentencepiece Tokenizer. Wordpiece
49 //         Tokenizer can be used for a MobileBert[0] model, Sentencepiece
//         Tokenizer can be used for an Albert[1] model
51 //       - 3 input tensors with names "ids", "mask" and "segment_ids"
52 //       - 2 output tensors with names "end_logits" and "start_logits"
53 //      [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
54 //      [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
55 //
56 //   CreateBertQuestionAnswererFromFile(path_to_model, path_to_vocab)
57 //     Creates a BertQuestionAnswerer from TFLite model file and vocab file for
58 //     WordPiece tokenizer. Used in C++ environment.
59 //     One suitable model is:
60 //       https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
61 //
62 //   CreateBertQuestionAnswererFromBuffer(model_buffer_data, model_buffer_size,
63 //                                        vocab_buffer_data, vocab_buffer_size)
64 //     Creates a BertQuestionAnswerer from TFLite model buffer and vocab file
//     buffer for WordPiece tokenizer. Used in Java (JNI) environment.
66 //
67 //   CreateAlbertQuestionAnswererFromFile(path_to_model, path_to_spmodel)
68 //     Creates an AlbertQuestionAnswerer from TFLite model file and
69 //     SentencePiece model file. Used in C++ environment.
70 //     One suitable model is:
71 //       https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
72 //
73 //   CreateAlbertQuestionAnswererFromBuffer(model_buffer_data,
74 //                                          model_buffer_size,
75 //                                          spmodel_buffer_data,
76 //                                          spmodel_buffer_size)
77 //     Creates an AlbertQuestionAnswerer from TFLite model file buffer and
//     SentencePiece model file buffer. Used in Java (JNI) environment.
79 //
80 
81 class BertQuestionAnswerer : public QuestionAnswerer {
82  public:
83   // TODO(b/150904655): add support to parameterize.
84   static constexpr int kMaxQueryLen = 64;
85   static constexpr int kMaxSeqLen = 384;
86   static constexpr int kPredictAnsNum = 5;
87   static constexpr int kMaxAnsLen = 32;
88   // TODO(b/151954803): clarify the offset usage
89   static constexpr int kOutputOffset = 1;
90   static constexpr int kNumLiteThreads = 4;
91   static constexpr bool kUseLowerCase = true;
92 
93   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
94   CreateFromFile(const std::string& path_to_model_with_metadata);
95 
96   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
97   CreateFromBuffer(const char* model_with_metadata_buffer_data,
98                    size_t model_with_metadata_buffer_size);
99 
100   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
101   CreateFromFd(int fd);
102 
103   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
104   CreateBertQuestionAnswererFromFile(const std::string& path_to_model,
105                                      const std::string& path_to_vocab);
106 
107   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
108   CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data,
109                                        size_t model_buffer_size,
110                                        const char* vocab_buffer_data,
111                                        size_t vocab_buffer_size);
112 
113   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
114   CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model,
115                                        const std::string& path_to_spmodel);
116 
117   static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
118   CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data,
119                                          size_t model_buffer_size,
120                                          const char* spmodel_buffer_data,
121                                          size_t spmodel_buffer_size);
122 
BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)123   explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
124       : QuestionAnswerer(std::move(engine)) {}
125 
126   // Answers question based on the context. Could be empty if no answer was
127   // found from the given context.
128   std::vector<QaAnswer> Answer(const std::string& context,
129                                const std::string& question) override;
130 
131  private:
132   absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
133                           const std::string& lowercased_context,
134                           const std::string& lowercased_query) override;
135 
136   tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess(
137       const std::vector<const TfLiteTensor*>& output_tensors,
138       const std::string& lowercased_context,
139       const std::string& lowercased_query) override;
140 
141   // Initialize API with a BertTokenizer from the vocabulary file.
142   void InitializeBertTokenizer(const std::string& path_to_vocab);
143   // Initialize API with a BertTokenizer from the vocabulary buffer.
144   void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data,
145                                          size_t vocab_buffer_size);
146 
147   // Initialize API with a SentencepieceTokenizer from the model file.
148   void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel);
149   // Initialize API with a SentencepieceTokenizer from the model buffer.
150   void InitializeSentencepieceTokenizerFromBinary(
151       const char* spmodel_buffer_data, size_t spmodel_buffer_size);
152 
153   // Initialize the API with the tokenizer set in the metadata.
154   absl::Status InitializeFromMetadata();
155 
156   std::string ConvertIndexToString(int start, int end);
157 
158   std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
159   // Maps index of input token to index of untokenized word from original input.
160   absl::flat_hash_map<size_t, size_t> token_to_orig_map_;
161   // Original tokens of context.
162   std::vector<std::string> orig_tokens_;
163 };
164 
165 }  // namespace qa
166 }  // namespace text
167 }  // namespace task
168 }  // namespace tflite
169 
170 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
171