/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/bert_tokenizer.h"

#include <algorithm>
#include <string>

#include "annotator/types.h"
#include "utils/tokenizer-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "absl/strings/string_view.h"

namespace libtextclassifier3 {

namespace {

int SafeLookup(const std::vector<int>& vector, int index) {
  if (vector.empty()) {
    return 0;
  }
  index = std::max(index, 0);
  index = std::min(index, static_cast<int>(vector.size()) - 1);
  return vector[index];
}

}  // namespace

FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
    const std::vector<std::string>& vocab)
    : vocab_{vocab} {
  for (int i = 0; i < vocab_.size(); ++i) {
    index_map_[vocab_[i]] = i;
  }
}

LookupStatus FlatHashMapBackedWordpiece::Contains(absl::string_view key,
                                                  bool* value) const {
  *value = index_map_.contains(key);
  return LookupStatus();
}

bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key,
                                          int* result) const {
  auto it = index_map_.find(key);
  if (it == index_map_.end()) {
    return false;
  }
  *result = it->second;
  return true;
}

bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id,
                                            absl::string_view* result) const {
  if (vocab_id >= vocab_.size() || vocab_id < 0) {
    return false;
  }
  *result = vocab_[vocab_id];
  return true;
}

TokenizerResult BertTokenizer::Tokenize(const std::string& input) {
  return TokenizeIntoWordpieces(input);
}

WordpieceTokenizerResult BertTokenizer::TokenizeIntoWordpieces(
    const std::string& input) {
  std::vector<Token> tokens =
      TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
  return TokenizeIntoWordpieces(tokens);
}

WordpieceTokenizerResult BertTokenizer::TokenizeSingleToken(
    const std::string& token) {
  const UnicodeText token_unicode =
      UTF8ToUnicodeText(token, /*do_copy=*/false);
  std::vector<Token> tokens = {
      Token(token, 0, token_unicode.size_codepoints())};
  return TokenizeIntoWordpieces(tokens);
}

WordpieceTokenizerResult BertTokenizer::TokenizeIntoWordpieces(
    const std::vector<Token>& tokens) {
  WordpieceTokenizerResult result;
  std::vector<std::string>& subwords = result.subwords;

  for (int token_index = 0; token_index < tokens.size(); token_index++) {
    const Token& token = tokens[token_index];
    int num_word_pieces = 0;
    std::vector<int> wp_absolute_begin_offset;
    std::vector<int> wp_absolute_end_offset;
    LookupStatus status = WordpieceTokenize(
        token.value, options_.max_bytes_per_token,
        options_.max_chars_per_subtoken, options_.suffix_indicator,
        options_.use_unknown_token, options_.unknown_token,
        options_.split_unknown_chars, &vocab_, &subwords,
        &wp_absolute_begin_offset, &wp_absolute_end_offset, &num_word_pieces);

    const UnicodeText token_unicode =
        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    std::vector<int> byte_to_codepoint_offsets;
    int byte_to_codepoint_offset = 0;
    for (const auto& it : token_unicode.Codepoints()) {
      byte_to_codepoint_offsets.resize(
          it.utf8_data() + it.utf8_length() - token_unicode.data(),
          byte_to_codepoint_offset++);
    }
    byte_to_codepoint_offsets.push_back(byte_to_codepoint_offset);

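    // Descriptive note (added for clarity): WordpieceTokenize reports byte
    // offsets into token.value. The loops below map each byte offset to a
    // codepoint offset via byte_to_codepoint_offsets (clamped by SafeLookup)
    // and shift it by token.start so it indexes into the full input text.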
    for (const int offset : wp_absolute_begin_offset) {
      result.wp_begin_offset.push_back(
          token.start + SafeLookup(byte_to_codepoint_offsets, offset));
    }
    for (const int offset : wp_absolute_end_offset) {
      result.wp_end_offset.push_back(
          token.start + SafeLookup(byte_to_codepoint_offsets, offset));
    }

    result.row_lengths.push_back(num_word_pieces);

    if (!status.success) {
      return result;
    }
  }

  return result;
}

// This replicates how the original bert_tokenizer from the tflite-support
// library pretokenizes text, using regex_split with its default regexes.
// It splits the text on spaces, punctuation, and Chinese characters, and
// outputs all the tokens except spaces.
// So far, the only difference from the original implementation that we are
// aware of is that the original regexes have 8 ranges of Chinese unicode
// code points; we have all these 8 ranges plus two extra ranges.
std::vector<std::string> BertTokenizer::PreTokenize(
    const absl::string_view input) {
  const std::vector<Token> tokens =
      TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
  std::vector<std::string> token_texts;
  std::transform(tokens.begin(), tokens.end(), std::back_inserter(token_texts),
                 [](Token const& token) { return std::move(token.value); });

  return token_texts;
}

}  // namespace libtextclassifier3
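// Illustrative example of the pretokenization behavior described above
// PreTokenize (the output values are an assumption based on the stated
// splitting rules, not a verified golden result):
//
//   tokenizer.PreTokenize("Hello, world! 你好")
//     -> {"Hello", ",", "world", "!", "你", "好"}
//
// Whitespace is dropped, while punctuation marks and Chinese characters each
// become standalone tokens.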