/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/bert_tokenizer.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include "annotator/types.h"
#include "utils/tokenizer-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "absl/strings/string_view.h"

namespace libtextclassifier3 {

namespace {

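// Returns vector[index] with the index clamped to the valid range of
// `vector`; returns 0 if the vector is empty.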
int SafeLookup(const std::vector<int>& vector, int index) {
  if (vector.empty()) {
    return 0;
  }
  index = std::max(index, 0);
  index = std::min(index, static_cast<int>(vector.size()) - 1);
  return vector[index];
}

}  // namespace

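// Copies the vocabulary and builds an index from each token string to its
// vocabulary id for constant-time lookup.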
FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
    const std::vector<std::string>& vocab)
    : vocab_{vocab} {
  for (int i = 0; i < vocab_.size(); ++i) {
    index_map_[vocab_[i]] = i;
  }
}

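// Sets `value` to whether `key` is in the vocabulary; the lookup itself
// cannot fail, so the returned status always reports success.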
LookupStatus FlatHashMapBackedWordpiece::Contains(absl::string_view key,
                                                  bool* value) const {
  *value = index_map_.contains(key);
  return LookupStatus();
}

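// Writes the vocabulary id of `key` to `result`; returns false if the key is
// not in the vocabulary.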
bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key,
                                          int* result) const {
  auto it = index_map_.find(key);
  if (it == index_map_.end()) {
    return false;
  }
  *result = it->second;
  return true;
}

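// Writes the token string for `vocab_id` to `result`; returns false if the id
// is out of range.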
bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id,
                                            absl::string_view* result) const {
  if (vocab_id >= vocab_.size() || vocab_id < 0) {
    return false;
  }
  *result = vocab_[vocab_id];
  return true;
}

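// Implements the generic Tokenize interface by delegating to the wordpiece
// tokenizer.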
TokenizerResult BertTokenizer::Tokenize(const std::string& input) {
  return TokenizeIntoWordpieces(input);
}

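// Pre-tokenizes `input` on whitespace, punctuation, and Chinese characters,
// then splits each resulting token into wordpieces.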
WordpieceTokenizerResult BertTokenizer::TokenizeIntoWordpieces(
    const std::string& input) {
  std::vector<Token> tokens =
      TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
  return TokenizeIntoWordpieces(tokens);
}

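// Splits a single, already pre-tokenized token into wordpieces, skipping the
// pre-tokenization step.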
WordpieceTokenizerResult BertTokenizer::TokenizeSingleToken(
    const std::string& token) {
  const UnicodeText token_unicode = UTF8ToUnicodeText(token, /*do_copy=*/false);
  std::vector<Token> tokens = {
      Token(token, 0, token_unicode.size_codepoints())};
  return TokenizeIntoWordpieces(tokens);
}

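// Runs the wordpiece algorithm over each token and converts the byte offsets
// reported by WordpieceTokenize into codepoint offsets relative to the
// original input.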
WordpieceTokenizerResult BertTokenizer::TokenizeIntoWordpieces(
    const std::vector<Token>& tokens) {
  WordpieceTokenizerResult result;
  std::vector<std::string>& subwords = result.subwords;

  for (int token_index = 0; token_index < tokens.size(); token_index++) {
    const Token& token = tokens[token_index];
    int num_word_pieces = 0;
    std::vector<int> wp_absolute_begin_offset;
    std::vector<int> wp_absolute_end_offset;
    LookupStatus status = WordpieceTokenize(
        token.value, options_.max_bytes_per_token,
        options_.max_chars_per_subtoken, options_.suffix_indicator,
        options_.use_unknown_token, options_.unknown_token,
        options_.split_unknown_chars, &vocab_, &subwords,
        &wp_absolute_begin_offset, &wp_absolute_end_offset, &num_word_pieces);
    const UnicodeText token_unicode =
        UTF8ToUnicodeText(token.value, /*do_copy=*/false);

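    // WordpieceTokenize reports offsets in bytes. Build a table mapping each
    // byte offset within the token to its codepoint offset so that the result
    // can be expressed in codepoints.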
    std::vector<int> byte_to_codepoint_offsets;
    int byte_to_codepoint_offset = 0;
    for (const auto& it : token_unicode.Codepoints()) {
      byte_to_codepoint_offsets.resize(
          it.utf8_data() + it.utf8_length() - token_unicode.data(),
          byte_to_codepoint_offset++);
    }
    // Add an entry for the end-of-token offset, one past the last byte.
    byte_to_codepoint_offsets.push_back(byte_to_codepoint_offset);

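    // Translate the wordpiece begin/end byte offsets into codepoint offsets
    // relative to the start of the full input.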
    for (const int offset : wp_absolute_begin_offset) {
      result.wp_begin_offset.push_back(
          token.start + SafeLookup(byte_to_codepoint_offsets, offset));
    }
    for (const int offset : wp_absolute_end_offset) {
      result.wp_end_offset.push_back(
          token.start + SafeLookup(byte_to_codepoint_offsets, offset));
    }
    result.row_lengths.push_back(num_word_pieces);

    // Stop early on a failed lookup and return what has been produced so far.
    if (!status.success) {
      return result;
    }
  }

  return result;
}

// This replicates how the original bert_tokenizer from the tflite-support
// library pretokenizes text, using regex_split with its default regexes: it
// splits the text on spaces, punctuation, and Chinese characters, and outputs
// all of the tokens except the spaces.
// So far, the only difference from the original implementation that we are
// aware of is that the original regexes cover 8 ranges of Chinese unicode
// characters; we cover those 8 ranges plus two extra ones.
std::vector<std::string> BertTokenizer::PreTokenize(
    const absl::string_view input) {
  const std::vector<Token> tokens =
      TokenizeOnWhiteSpacePunctuationAndChineseLetter(input);
  std::vector<std::string> token_texts;
  std::transform(tokens.begin(), tokens.end(), std::back_inserter(token_texts),
                 // `token` is const, so std::move would silently copy anyway;
                 // copy explicitly instead.
                 [](const Token& token) { return token.value; });

  return token_texts;
}
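
// A rough illustration of the expected behavior (assuming default options;
// the exact splits depend on the regex ranges described above):
//
//   BertTokenizer tokenizer(vocab);  // `vocab` is a hypothetical vocabulary.
//   tokenizer.PreTokenize("Hello, world!");
//   // -> {"Hello", ",", "world", "!"}: spaces are dropped, and each
//   // punctuation character becomes its own token.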

}  // namespace libtextclassifier3