/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"

namespace tflite {
namespace task {
namespace text {
namespace qa {

// BertQA task API: performs tokenization for models (BERT, Albert, etc.) in
// preprocess and returns the most likely answers.
//
// In particular, the branch of BERT models uses the WordPiece tokenizer, and
// the branch of Albert models uses the SentencePiece tokenizer, respectively.
39 // 40 // Factory methods: 41 // CreateFromFile(path_to_model_with_metadata) 42 // CreateFromBuffer(model_with_metadata_buffer_data, 43 // model_with_metadata_buffer_size) 44 // CreateFromFd(file_descriptor_to_model_with_metadata) 45 // Generic API to create the QuestionAnswerer for bert models with metadata 46 // populated. The API expects a Bert based TFLite model with metadata 47 // containing the following information: 48 // - input_process_units for Wordpiece/Sentencepiece Tokenizer. Wordpiece 49 // Tokenizer can be used for a MobileBert[0] model, Sentencepiece 50 // Tokenizer Tokenizer can be used for an Albert[1] model 51 // - 3 input tensors with names "ids", "mask" and "segment_ids" 52 // - 2 output tensors with names "end_logits" and "start_logits" 53 // [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 54 // [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 55 // 56 // CreateBertQuestionAnswererFromFile(path_to_model, path_to_vocab) 57 // Creates a BertQuestionAnswerer from TFLite model file and vocab file for 58 // WordPiece tokenizer. Used in C++ environment. 59 // One suitable model is: 60 // https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 61 // 62 // CreateBertQuestionAnswererFromBuffer(model_buffer_data, model_buffer_size, 63 // vocab_buffer_data, vocab_buffer_size) 64 // Creates a BertQuestionAnswerer from TFLite model buffer and vocab file 65 // buffer for WordPiece tokenizer. Used in Jave (JNI) environment. 66 // 67 // CreateAlbertQuestionAnswererFromFile(path_to_model, path_to_spmodel) 68 // Creates an AlbertQuestionAnswerer from TFLite model file and 69 // SentencePiece model file. Used in C++ environment. 
70 // One suitable model is: 71 // https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 72 // 73 // CreateAlbertQuestionAnswererFromBuffer(model_buffer_data, 74 // model_buffer_size, 75 // spmodel_buffer_data, 76 // spmodel_buffer_size) 77 // Creates an AlbertQuestionAnswerer from TFLite model file buffer and 78 // SentencePiece model file buffer. Used in Jave (JNI) environment. 79 // 80 81 class BertQuestionAnswerer : public QuestionAnswerer { 82 public: 83 // TODO(b/150904655): add support to parameterize. 84 static constexpr int kMaxQueryLen = 64; 85 static constexpr int kMaxSeqLen = 384; 86 static constexpr int kPredictAnsNum = 5; 87 static constexpr int kMaxAnsLen = 32; 88 // TODO(b/151954803): clarify the offset usage 89 static constexpr int kOutputOffset = 1; 90 static constexpr int kNumLiteThreads = 4; 91 static constexpr bool kUseLowerCase = true; 92 93 static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 94 CreateFromFile(const std::string& path_to_model_with_metadata); 95 96 static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 97 CreateFromBuffer(const char* model_with_metadata_buffer_data, 98 size_t model_with_metadata_buffer_size); 99 100 static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 101 CreateFromFd(int fd); 102 103 static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 104 CreateBertQuestionAnswererFromFile(const std::string& path_to_model, 105 const std::string& path_to_vocab); 106 107 static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 108 CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data, 109 size_t model_buffer_size, 110 const char* vocab_buffer_data, 111 size_t vocab_buffer_size); 112 113 static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 114 CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model, 115 const std::string& path_to_spmodel); 116 117 static 
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> 118 CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data, 119 size_t model_buffer_size, 120 const char* spmodel_buffer_data, 121 size_t spmodel_buffer_size); 122 BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)123 explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) 124 : QuestionAnswerer(std::move(engine)) {} 125 126 // Answers question based on the context. Could be empty if no answer was 127 // found from the given context. 128 std::vector<QaAnswer> Answer(const std::string& context, 129 const std::string& question) override; 130 131 private: 132 absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, 133 const std::string& lowercased_context, 134 const std::string& lowercased_query) override; 135 136 tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess( 137 const std::vector<const TfLiteTensor*>& output_tensors, 138 const std::string& lowercased_context, 139 const std::string& lowercased_query) override; 140 141 // Initialize API with a BertTokenizer from the vocabulary file. 142 void InitializeBertTokenizer(const std::string& path_to_vocab); 143 // Initialize API with a BertTokenizer from the vocabulary buffer. 144 void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data, 145 size_t vocab_buffer_size); 146 147 // Initialize API with a SentencepieceTokenizer from the model file. 148 void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel); 149 // Initialize API with a SentencepieceTokenizer from the model buffer. 150 void InitializeSentencepieceTokenizerFromBinary( 151 const char* spmodel_buffer_data, size_t spmodel_buffer_size); 152 153 // Initialize the API with the tokenizer set in the metadata. 
154 absl::Status InitializeFromMetadata(); 155 156 std::string ConvertIndexToString(int start, int end); 157 158 std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; 159 // Maps index of input token to index of untokenized word from original input. 160 absl::flat_hash_map<size_t, size_t> token_to_orig_map_; 161 // Original tokens of context. 162 std::vector<std::string> orig_tokens_; 163 }; 164 165 } // namespace qa 166 } // namespace text 167 } // namespace task 168 } // namespace tflite 169 170 #endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ 171