xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/string_projection.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
#include "utils/tflite/string_projection.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "utils/strings/utf8.h"
#include "utils/tflite/string_projection_base.h"
#include "utils/utf8/unilib-common.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/string_util.h"
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker namespace tflite {
32*993b0882SAndroid Build Coastguard Worker namespace ops {
33*993b0882SAndroid Build Coastguard Worker namespace custom {
34*993b0882SAndroid Build Coastguard Worker 
35*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
36*993b0882SAndroid Build Coastguard Worker namespace string_projection {
37*993b0882SAndroid Build Coastguard Worker namespace {
38*993b0882SAndroid Build Coastguard Worker 
// Boundary marker strings. kEmptyToken is the two markers joined by a space
// (presumably the stand-in for empty input; its use is outside this view).
constexpr char kStartToken[] = "<S>";
constexpr char kEndToken[] = "<E>";
constexpr char kEmptyToken[] = "<S> <E>";

// Size sentinels meaning "no limit" for max_input / max_tokens arguments.
constexpr size_t kEntireString = SIZE_MAX;
constexpr size_t kAllTokens = SIZE_MAX;
// "Not found" sentinel. It is deliberately an int: compared against a size_t
// it converts to SIZE_MAX, the same value as the sentinels above.
constexpr int kInvalid = -1;

// Single characters with special meaning to the tokenizers.
constexpr char kApostrophe = '\'';
constexpr char kSpace = ' ';
constexpr char kComma = ',';
constexpr char kDot = '.';
50*993b0882SAndroid Build Coastguard Worker 
51*993b0882SAndroid Build Coastguard Worker // Returns true if the given text contains a number.
IsDigitString(const std::string & text)52*993b0882SAndroid Build Coastguard Worker bool IsDigitString(const std::string& text) {
53*993b0882SAndroid Build Coastguard Worker   for (size_t i = 0; i < text.length();) {
54*993b0882SAndroid Build Coastguard Worker     const int bytes_read =
55*993b0882SAndroid Build Coastguard Worker         ::libtextclassifier3::GetNumBytesForUTF8Char(text.data());
56*993b0882SAndroid Build Coastguard Worker     if (bytes_read <= 0 || bytes_read > text.length() - i) {
57*993b0882SAndroid Build Coastguard Worker       break;
58*993b0882SAndroid Build Coastguard Worker     }
59*993b0882SAndroid Build Coastguard Worker     const char32_t rune = ::libtextclassifier3::ValidCharToRune(text.data());
60*993b0882SAndroid Build Coastguard Worker     if (::libtextclassifier3::IsDigit(rune)) return true;
61*993b0882SAndroid Build Coastguard Worker     i += bytes_read;
62*993b0882SAndroid Build Coastguard Worker   }
63*993b0882SAndroid Build Coastguard Worker   return false;
64*993b0882SAndroid Build Coastguard Worker }
65*993b0882SAndroid Build Coastguard Worker 
// Concatenates char_tokens[start, start + num_chars). Returns an empty string
// when fewer than |num_chars| tokens remain from |start|.
std::string GetCharToken(const std::vector<std::string>& char_tokens, int start,
                         int num_chars) {
  std::string joined;
  if (start + num_chars > char_tokens.size()) {
    return joined;
  }
  const int stop = start + num_chars;
  for (int idx = start; idx < stop; ++idx) {
    joined += char_tokens[idx];
  }
  return joined;
}
77*993b0882SAndroid Build Coastguard Worker 
78*993b0882SAndroid Build Coastguard Worker // Counts how many times |pattern| appeared from |start| position.
GetNumPattern(const std::vector<std::string> & char_tokens,size_t start,size_t num_chars,const std::string & pattern)79*993b0882SAndroid Build Coastguard Worker int GetNumPattern(const std::vector<std::string>& char_tokens, size_t start,
80*993b0882SAndroid Build Coastguard Worker                   size_t num_chars, const std::string& pattern) {
81*993b0882SAndroid Build Coastguard Worker   int count = 0;
82*993b0882SAndroid Build Coastguard Worker   for (int i = start; i < char_tokens.size(); i += num_chars) {
83*993b0882SAndroid Build Coastguard Worker     std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
84*993b0882SAndroid Build Coastguard Worker     if (pattern == cur_pattern) {
85*993b0882SAndroid Build Coastguard Worker       ++count;
86*993b0882SAndroid Build Coastguard Worker     } else {
87*993b0882SAndroid Build Coastguard Worker       break;
88*993b0882SAndroid Build Coastguard Worker     }
89*993b0882SAndroid Build Coastguard Worker   }
90*993b0882SAndroid Build Coastguard Worker   return count;
91*993b0882SAndroid Build Coastguard Worker }
92*993b0882SAndroid Build Coastguard Worker 
FindNextSpace(const char * input_ptr,size_t from,size_t length)93*993b0882SAndroid Build Coastguard Worker inline size_t FindNextSpace(const char* input_ptr, size_t from, size_t length) {
94*993b0882SAndroid Build Coastguard Worker   size_t space_index;
95*993b0882SAndroid Build Coastguard Worker   for (space_index = from; space_index < length; space_index++) {
96*993b0882SAndroid Build Coastguard Worker     if (input_ptr[space_index] == kSpace) {
97*993b0882SAndroid Build Coastguard Worker       break;
98*993b0882SAndroid Build Coastguard Worker     }
99*993b0882SAndroid Build Coastguard Worker   }
100*993b0882SAndroid Build Coastguard Worker   return space_index == length ? kInvalid : space_index;
101*993b0882SAndroid Build Coastguard Worker }
102*993b0882SAndroid Build Coastguard Worker 
103*993b0882SAndroid Build Coastguard Worker template <typename T>
SplitByCharInternal(std::vector<T> * tokens,const char * input_ptr,size_t len,size_t max_tokens)104*993b0882SAndroid Build Coastguard Worker void SplitByCharInternal(std::vector<T>* tokens, const char* input_ptr,
105*993b0882SAndroid Build Coastguard Worker                          size_t len, size_t max_tokens) {
106*993b0882SAndroid Build Coastguard Worker   for (size_t i = 0; i < len;) {
107*993b0882SAndroid Build Coastguard Worker     auto bytes_read =
108*993b0882SAndroid Build Coastguard Worker         ::libtextclassifier3::GetNumBytesForUTF8Char(input_ptr + i);
109*993b0882SAndroid Build Coastguard Worker     if (bytes_read <= 0 || bytes_read > len - i) break;
110*993b0882SAndroid Build Coastguard Worker     tokens->emplace_back(input_ptr + i, bytes_read);
111*993b0882SAndroid Build Coastguard Worker     if (max_tokens != kInvalid && tokens->size() == max_tokens) {
112*993b0882SAndroid Build Coastguard Worker       break;
113*993b0882SAndroid Build Coastguard Worker     }
114*993b0882SAndroid Build Coastguard Worker     i += bytes_read;
115*993b0882SAndroid Build Coastguard Worker   }
116*993b0882SAndroid Build Coastguard Worker }
117*993b0882SAndroid Build Coastguard Worker 
SplitByChar(const char * input_ptr,size_t len,size_t max_tokens)118*993b0882SAndroid Build Coastguard Worker std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
119*993b0882SAndroid Build Coastguard Worker                                      size_t max_tokens) {
120*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> tokens;
121*993b0882SAndroid Build Coastguard Worker   SplitByCharInternal(&tokens, input_ptr, len, max_tokens);
122*993b0882SAndroid Build Coastguard Worker   return tokens;
123*993b0882SAndroid Build Coastguard Worker }
124*993b0882SAndroid Build Coastguard Worker 
ContractToken(const char * input_ptr,size_t len,size_t num_chars)125*993b0882SAndroid Build Coastguard Worker std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
126*993b0882SAndroid Build Coastguard Worker   // This function contracts patterns whose length is |num_chars| and appeared
127*993b0882SAndroid Build Coastguard Worker   // more than twice. So if the input is shorter than 3 * |num_chars|, do not
128*993b0882SAndroid Build Coastguard Worker   // apply any contraction.
129*993b0882SAndroid Build Coastguard Worker   if (len < 3 * num_chars) {
130*993b0882SAndroid Build Coastguard Worker     return input_ptr;
131*993b0882SAndroid Build Coastguard Worker   }
132*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> char_tokens = SplitByChar(input_ptr, len, len);
133*993b0882SAndroid Build Coastguard Worker 
134*993b0882SAndroid Build Coastguard Worker   std::string token;
135*993b0882SAndroid Build Coastguard Worker   token.reserve(len);
136*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < char_tokens.size();) {
137*993b0882SAndroid Build Coastguard Worker     std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
138*993b0882SAndroid Build Coastguard Worker 
139*993b0882SAndroid Build Coastguard Worker     // Count how many times this pattern appeared.
140*993b0882SAndroid Build Coastguard Worker     int num_cur_patterns = 0;
141*993b0882SAndroid Build Coastguard Worker     if (!absl::StrContains(cur_pattern, " ") && !IsDigitString(cur_pattern)) {
142*993b0882SAndroid Build Coastguard Worker       num_cur_patterns =
143*993b0882SAndroid Build Coastguard Worker           GetNumPattern(char_tokens, i + num_chars, num_chars, cur_pattern);
144*993b0882SAndroid Build Coastguard Worker     }
145*993b0882SAndroid Build Coastguard Worker 
146*993b0882SAndroid Build Coastguard Worker     if (num_cur_patterns >= 2) {
147*993b0882SAndroid Build Coastguard Worker       // If this pattern is repeated, store it only twice.
148*993b0882SAndroid Build Coastguard Worker       token.append(cur_pattern);
149*993b0882SAndroid Build Coastguard Worker       token.append(cur_pattern);
150*993b0882SAndroid Build Coastguard Worker       i += (num_cur_patterns + 1) * num_chars;
151*993b0882SAndroid Build Coastguard Worker     } else {
152*993b0882SAndroid Build Coastguard Worker       token.append(char_tokens[i]);
153*993b0882SAndroid Build Coastguard Worker       ++i;
154*993b0882SAndroid Build Coastguard Worker     }
155*993b0882SAndroid Build Coastguard Worker   }
156*993b0882SAndroid Build Coastguard Worker 
157*993b0882SAndroid Build Coastguard Worker   return token;
158*993b0882SAndroid Build Coastguard Worker }
159*993b0882SAndroid Build Coastguard Worker 
160*993b0882SAndroid Build Coastguard Worker template <typename T>
SplitBySpaceInternal(std::vector<T> * tokens,const char * input_ptr,size_t len,size_t max_input,size_t max_tokens)161*993b0882SAndroid Build Coastguard Worker void SplitBySpaceInternal(std::vector<T>* tokens, const char* input_ptr,
162*993b0882SAndroid Build Coastguard Worker                           size_t len, size_t max_input, size_t max_tokens) {
163*993b0882SAndroid Build Coastguard Worker   size_t last_index =
164*993b0882SAndroid Build Coastguard Worker       max_input == kEntireString ? len : (len < max_input ? len : max_input);
165*993b0882SAndroid Build Coastguard Worker   size_t start = 0;
166*993b0882SAndroid Build Coastguard Worker   // skip leading spaces
167*993b0882SAndroid Build Coastguard Worker   while (start < last_index && input_ptr[start] == kSpace) {
168*993b0882SAndroid Build Coastguard Worker     start++;
169*993b0882SAndroid Build Coastguard Worker   }
170*993b0882SAndroid Build Coastguard Worker   auto end = FindNextSpace(input_ptr, start, last_index);
171*993b0882SAndroid Build Coastguard Worker   while (end != kInvalid &&
172*993b0882SAndroid Build Coastguard Worker          (max_tokens == kAllTokens || tokens->size() < max_tokens - 1)) {
173*993b0882SAndroid Build Coastguard Worker     auto length = end - start;
174*993b0882SAndroid Build Coastguard Worker     if (length > 0) {
175*993b0882SAndroid Build Coastguard Worker       tokens->emplace_back(input_ptr + start, length);
176*993b0882SAndroid Build Coastguard Worker     }
177*993b0882SAndroid Build Coastguard Worker 
178*993b0882SAndroid Build Coastguard Worker     start = end + 1;
179*993b0882SAndroid Build Coastguard Worker     end = FindNextSpace(input_ptr, start, last_index);
180*993b0882SAndroid Build Coastguard Worker   }
181*993b0882SAndroid Build Coastguard Worker   auto length = end == kInvalid ? (last_index - start) : (end - start);
182*993b0882SAndroid Build Coastguard Worker   if (length > 0) {
183*993b0882SAndroid Build Coastguard Worker     tokens->emplace_back(input_ptr + start, length);
184*993b0882SAndroid Build Coastguard Worker   }
185*993b0882SAndroid Build Coastguard Worker }
186*993b0882SAndroid Build Coastguard Worker 
SplitBySpace(const char * input_ptr,size_t len,size_t max_input,size_t max_tokens)187*993b0882SAndroid Build Coastguard Worker std::vector<std::string> SplitBySpace(const char* input_ptr, size_t len,
188*993b0882SAndroid Build Coastguard Worker                                       size_t max_input, size_t max_tokens) {
189*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> tokens;
190*993b0882SAndroid Build Coastguard Worker   SplitBySpaceInternal(&tokens, input_ptr, len, max_input, max_tokens);
191*993b0882SAndroid Build Coastguard Worker   return tokens;
192*993b0882SAndroid Build Coastguard Worker }
193*993b0882SAndroid Build Coastguard Worker 
prepend_separator(char separator)194*993b0882SAndroid Build Coastguard Worker bool prepend_separator(char separator) { return separator == kApostrophe; }
195*993b0882SAndroid Build Coastguard Worker 
// True iff |c| is an ASCII decimal digit ('0' through '9'); locale-independent.
bool is_numeric(char c) {
  return '0' <= c && c <= '9';
}
197*993b0882SAndroid Build Coastguard Worker 
198*993b0882SAndroid Build Coastguard Worker class ProjectionNormalizer {
199*993b0882SAndroid Build Coastguard Worker  public:
ProjectionNormalizer(const std::string & separators,bool normalize_repetition=false)200*993b0882SAndroid Build Coastguard Worker   explicit ProjectionNormalizer(const std::string& separators,
201*993b0882SAndroid Build Coastguard Worker                                 bool normalize_repetition = false) {
202*993b0882SAndroid Build Coastguard Worker     InitializeSeparators(separators);
203*993b0882SAndroid Build Coastguard Worker     normalize_repetition_ = normalize_repetition;
204*993b0882SAndroid Build Coastguard Worker   }
205*993b0882SAndroid Build Coastguard Worker 
206*993b0882SAndroid Build Coastguard Worker   // Normalizes the repeated characters (except numbers) which consecutively
207*993b0882SAndroid Build Coastguard Worker   // appeared more than twice in a word.
Normalize(const std::string & input,size_t max_input=300)208*993b0882SAndroid Build Coastguard Worker   std::string Normalize(const std::string& input, size_t max_input = 300) {
209*993b0882SAndroid Build Coastguard Worker     return Normalize(input.data(), input.length(), max_input);
210*993b0882SAndroid Build Coastguard Worker   }
Normalize(const char * input_ptr,size_t len,size_t max_input=300)211*993b0882SAndroid Build Coastguard Worker   std::string Normalize(const char* input_ptr, size_t len,
212*993b0882SAndroid Build Coastguard Worker                         size_t max_input = 300) {
213*993b0882SAndroid Build Coastguard Worker     std::string normalized(input_ptr, std::min(len, max_input));
214*993b0882SAndroid Build Coastguard Worker 
215*993b0882SAndroid Build Coastguard Worker     if (normalize_repetition_) {
216*993b0882SAndroid Build Coastguard Worker       // Remove repeated 1 char (e.g. soooo => soo)
217*993b0882SAndroid Build Coastguard Worker       normalized = ContractToken(normalized.data(), normalized.length(), 1);
218*993b0882SAndroid Build Coastguard Worker 
219*993b0882SAndroid Build Coastguard Worker       // Remove repeated 2 chars from the beginning (e.g. hahaha =>
220*993b0882SAndroid Build Coastguard Worker       // haha, xhahaha => xhaha, xyhahaha => xyhaha).
221*993b0882SAndroid Build Coastguard Worker       normalized = ContractToken(normalized.data(), normalized.length(), 2);
222*993b0882SAndroid Build Coastguard Worker 
223*993b0882SAndroid Build Coastguard Worker       // Remove repeated 3 chars from the beginning
224*993b0882SAndroid Build Coastguard Worker       // (e.g. wowwowwow => wowwow, abcdbcdbcd => abcdbcd).
225*993b0882SAndroid Build Coastguard Worker       normalized = ContractToken(normalized.data(), normalized.length(), 3);
226*993b0882SAndroid Build Coastguard Worker     }
227*993b0882SAndroid Build Coastguard Worker 
228*993b0882SAndroid Build Coastguard Worker     if (!separators_.empty()) {
229*993b0882SAndroid Build Coastguard Worker       // Add space around separators_.
230*993b0882SAndroid Build Coastguard Worker       normalized = NormalizeInternal(normalized.data(), normalized.length());
231*993b0882SAndroid Build Coastguard Worker     }
232*993b0882SAndroid Build Coastguard Worker     return normalized;
233*993b0882SAndroid Build Coastguard Worker   }
234*993b0882SAndroid Build Coastguard Worker 
235*993b0882SAndroid Build Coastguard Worker  private:
236*993b0882SAndroid Build Coastguard Worker   // Parses and extracts supported separators.
InitializeSeparators(const std::string & separators)237*993b0882SAndroid Build Coastguard Worker   void InitializeSeparators(const std::string& separators) {
238*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < separators.length(); ++i) {
239*993b0882SAndroid Build Coastguard Worker       if (separators[i] != ' ') {
240*993b0882SAndroid Build Coastguard Worker         separators_.insert(separators[i]);
241*993b0882SAndroid Build Coastguard Worker       }
242*993b0882SAndroid Build Coastguard Worker     }
243*993b0882SAndroid Build Coastguard Worker   }
244*993b0882SAndroid Build Coastguard Worker 
245*993b0882SAndroid Build Coastguard Worker   // Removes repeated chars.
NormalizeInternal(const char * input_ptr,size_t len)246*993b0882SAndroid Build Coastguard Worker   std::string NormalizeInternal(const char* input_ptr, size_t len) {
247*993b0882SAndroid Build Coastguard Worker     std::string normalized;
248*993b0882SAndroid Build Coastguard Worker     normalized.reserve(len * 2);
249*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < len; ++i) {
250*993b0882SAndroid Build Coastguard Worker       char c = input_ptr[i];
251*993b0882SAndroid Build Coastguard Worker       bool matched_separator = separators_.find(c) != separators_.end();
252*993b0882SAndroid Build Coastguard Worker       if (matched_separator) {
253*993b0882SAndroid Build Coastguard Worker         if (i > 0 && input_ptr[i - 1] != ' ' && normalized.back() != ' ') {
254*993b0882SAndroid Build Coastguard Worker           normalized.append(" ");
255*993b0882SAndroid Build Coastguard Worker         }
256*993b0882SAndroid Build Coastguard Worker       }
257*993b0882SAndroid Build Coastguard Worker       normalized.append(1, c);
258*993b0882SAndroid Build Coastguard Worker       if (matched_separator) {
259*993b0882SAndroid Build Coastguard Worker         if (i + 1 < len && input_ptr[i + 1] != ' ' && c != '\'') {
260*993b0882SAndroid Build Coastguard Worker           normalized.append(" ");
261*993b0882SAndroid Build Coastguard Worker         }
262*993b0882SAndroid Build Coastguard Worker       }
263*993b0882SAndroid Build Coastguard Worker     }
264*993b0882SAndroid Build Coastguard Worker     return normalized;
265*993b0882SAndroid Build Coastguard Worker   }
266*993b0882SAndroid Build Coastguard Worker 
267*993b0882SAndroid Build Coastguard Worker   absl::flat_hash_set<char> separators_;
268*993b0882SAndroid Build Coastguard Worker   bool normalize_repetition_;
269*993b0882SAndroid Build Coastguard Worker };
270*993b0882SAndroid Build Coastguard Worker 
271*993b0882SAndroid Build Coastguard Worker class ProjectionTokenizer {
272*993b0882SAndroid Build Coastguard Worker  public:
ProjectionTokenizer(const std::string & separators)273*993b0882SAndroid Build Coastguard Worker   explicit ProjectionTokenizer(const std::string& separators) {
274*993b0882SAndroid Build Coastguard Worker     InitializeSeparators(separators);
275*993b0882SAndroid Build Coastguard Worker   }
276*993b0882SAndroid Build Coastguard Worker 
277*993b0882SAndroid Build Coastguard Worker   // Tokenizes the input by separators_. Limit to max_tokens, when it is not -1.
Tokenize(const std::string & input,size_t max_input,size_t max_tokens) const278*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> Tokenize(const std::string& input, size_t max_input,
279*993b0882SAndroid Build Coastguard Worker                                     size_t max_tokens) const {
280*993b0882SAndroid Build Coastguard Worker     return Tokenize(input.c_str(), input.size(), max_input, max_tokens);
281*993b0882SAndroid Build Coastguard Worker   }
282*993b0882SAndroid Build Coastguard Worker 
Tokenize(const char * input_ptr,size_t len,size_t max_input,size_t max_tokens) const283*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> Tokenize(const char* input_ptr, size_t len,
284*993b0882SAndroid Build Coastguard Worker                                     size_t max_input, size_t max_tokens) const {
285*993b0882SAndroid Build Coastguard Worker     // If separators_ is not given, tokenize the input with a space.
286*993b0882SAndroid Build Coastguard Worker     if (separators_.empty()) {
287*993b0882SAndroid Build Coastguard Worker       return SplitBySpace(input_ptr, len, max_input, max_tokens);
288*993b0882SAndroid Build Coastguard Worker     }
289*993b0882SAndroid Build Coastguard Worker 
290*993b0882SAndroid Build Coastguard Worker     std::vector<std::string> tokens;
291*993b0882SAndroid Build Coastguard Worker     size_t last_index =
292*993b0882SAndroid Build Coastguard Worker         max_input == kEntireString ? len : (len < max_input ? len : max_input);
293*993b0882SAndroid Build Coastguard Worker     size_t start = 0;
294*993b0882SAndroid Build Coastguard Worker     // Skip leading spaces.
295*993b0882SAndroid Build Coastguard Worker     while (start < last_index && input_ptr[start] == kSpace) {
296*993b0882SAndroid Build Coastguard Worker       start++;
297*993b0882SAndroid Build Coastguard Worker     }
298*993b0882SAndroid Build Coastguard Worker     auto end = FindNextSeparator(input_ptr, start, last_index);
299*993b0882SAndroid Build Coastguard Worker 
300*993b0882SAndroid Build Coastguard Worker     while (end != kInvalid &&
301*993b0882SAndroid Build Coastguard Worker            (max_tokens == kAllTokens || tokens.size() < max_tokens - 1)) {
302*993b0882SAndroid Build Coastguard Worker       auto length = end - start;
303*993b0882SAndroid Build Coastguard Worker       if (length > 0) tokens.emplace_back(input_ptr + start, length);
304*993b0882SAndroid Build Coastguard Worker 
305*993b0882SAndroid Build Coastguard Worker       // Add the separator (except space and apostrophe) as a token
306*993b0882SAndroid Build Coastguard Worker       char separator = input_ptr[end];
307*993b0882SAndroid Build Coastguard Worker       if (separator != kSpace && separator != kApostrophe) {
308*993b0882SAndroid Build Coastguard Worker         tokens.emplace_back(input_ptr + end, 1);
309*993b0882SAndroid Build Coastguard Worker       }
310*993b0882SAndroid Build Coastguard Worker 
311*993b0882SAndroid Build Coastguard Worker       start = end + (prepend_separator(separator) ? 0 : 1);
312*993b0882SAndroid Build Coastguard Worker       end = FindNextSeparator(input_ptr, end + 1, last_index);
313*993b0882SAndroid Build Coastguard Worker     }
314*993b0882SAndroid Build Coastguard Worker     auto length = end == kInvalid ? (last_index - start) : (end - start);
315*993b0882SAndroid Build Coastguard Worker     if (length > 0) tokens.emplace_back(input_ptr + start, length);
316*993b0882SAndroid Build Coastguard Worker     return tokens;
317*993b0882SAndroid Build Coastguard Worker   }
318*993b0882SAndroid Build Coastguard Worker 
319*993b0882SAndroid Build Coastguard Worker  private:
320*993b0882SAndroid Build Coastguard Worker   // Parses and extracts supported separators.
InitializeSeparators(const std::string & separators)321*993b0882SAndroid Build Coastguard Worker   void InitializeSeparators(const std::string& separators) {
322*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < separators.length(); ++i) {
323*993b0882SAndroid Build Coastguard Worker       separators_.insert(separators[i]);
324*993b0882SAndroid Build Coastguard Worker     }
325*993b0882SAndroid Build Coastguard Worker   }
326*993b0882SAndroid Build Coastguard Worker 
327*993b0882SAndroid Build Coastguard Worker   // Starting from input_ptr[from], search for the next occurrence of
328*993b0882SAndroid Build Coastguard Worker   // separators_. Don't search beyond input_ptr[length](non-inclusive). Return
329*993b0882SAndroid Build Coastguard Worker   // -1 if not found.
FindNextSeparator(const char * input_ptr,size_t from,size_t length) const330*993b0882SAndroid Build Coastguard Worker   size_t FindNextSeparator(const char* input_ptr, size_t from,
331*993b0882SAndroid Build Coastguard Worker                            size_t length) const {
332*993b0882SAndroid Build Coastguard Worker     auto index = from;
333*993b0882SAndroid Build Coastguard Worker     while (index < length) {
334*993b0882SAndroid Build Coastguard Worker       char c = input_ptr[index];
335*993b0882SAndroid Build Coastguard Worker       // Do not break a number (e.g. "10,000", "0.23").
336*993b0882SAndroid Build Coastguard Worker       if (c == kComma || c == kDot) {
337*993b0882SAndroid Build Coastguard Worker         if (index + 1 < length && is_numeric(input_ptr[index + 1])) {
338*993b0882SAndroid Build Coastguard Worker           c = input_ptr[++index];
339*993b0882SAndroid Build Coastguard Worker         }
340*993b0882SAndroid Build Coastguard Worker       }
341*993b0882SAndroid Build Coastguard Worker       if (separators_.find(c) != separators_.end()) {
342*993b0882SAndroid Build Coastguard Worker         break;
343*993b0882SAndroid Build Coastguard Worker       }
344*993b0882SAndroid Build Coastguard Worker       ++index;
345*993b0882SAndroid Build Coastguard Worker     }
346*993b0882SAndroid Build Coastguard Worker     return index == length ? kInvalid : index;
347*993b0882SAndroid Build Coastguard Worker   }
348*993b0882SAndroid Build Coastguard Worker 
349*993b0882SAndroid Build Coastguard Worker   absl::flat_hash_set<char> separators_;
350*993b0882SAndroid Build Coastguard Worker };
351*993b0882SAndroid Build Coastguard Worker 
// Removes the trailing run of characters of |*str| for which std::ispunct is
// true (ASCII punctuation in the default "C" locale). The first byte from the
// end that is not punctuation stops the stripping; in the "C" locale that
// includes the bytes of multi-byte UTF-8 characters.
inline void StripTrailingAsciiPunctuation(std::string* str) {
  // Cast to unsigned char: passing a negative char (e.g. a UTF-8 byte on a
  // platform where char is signed) to ispunct is undefined behavior.
  const auto first_kept_from_end = std::find_if_not(
      str->rbegin(), str->rend(),
      [](unsigned char c) { return std::ispunct(c) != 0; });
  // base() points just past the last character to keep.
  str->erase(first_kept_from_end.base(), str->end());
}
356*993b0882SAndroid Build Coastguard Worker 
PreProcessString(const char * str,int len,const bool remove_punctuation)357*993b0882SAndroid Build Coastguard Worker std::string PreProcessString(const char* str, int len,
358*993b0882SAndroid Build Coastguard Worker                              const bool remove_punctuation) {
359*993b0882SAndroid Build Coastguard Worker   std::string output_str(str, len);
360*993b0882SAndroid Build Coastguard Worker   std::transform(output_str.begin(), output_str.end(), output_str.begin(),
361*993b0882SAndroid Build Coastguard Worker                  ::tolower);
362*993b0882SAndroid Build Coastguard Worker 
363*993b0882SAndroid Build Coastguard Worker   // Remove trailing punctuation.
364*993b0882SAndroid Build Coastguard Worker   if (remove_punctuation) {
365*993b0882SAndroid Build Coastguard Worker     StripTrailingAsciiPunctuation(&output_str);
366*993b0882SAndroid Build Coastguard Worker   }
367*993b0882SAndroid Build Coastguard Worker 
368*993b0882SAndroid Build Coastguard Worker   if (output_str.empty()) {
369*993b0882SAndroid Build Coastguard Worker     output_str.assign(str, len);
370*993b0882SAndroid Build Coastguard Worker   }
371*993b0882SAndroid Build Coastguard Worker   return output_str;
372*993b0882SAndroid Build Coastguard Worker }
373*993b0882SAndroid Build Coastguard Worker 
ShouldIncludeCurrentNgram(const SkipGramParams & params,int size)374*993b0882SAndroid Build Coastguard Worker bool ShouldIncludeCurrentNgram(const SkipGramParams& params, int size) {
375*993b0882SAndroid Build Coastguard Worker   if (size <= 0) {
376*993b0882SAndroid Build Coastguard Worker     return false;
377*993b0882SAndroid Build Coastguard Worker   }
378*993b0882SAndroid Build Coastguard Worker   if (params.include_all_ngrams) {
379*993b0882SAndroid Build Coastguard Worker     return size <= params.ngram_size;
380*993b0882SAndroid Build Coastguard Worker   } else {
381*993b0882SAndroid Build Coastguard Worker     return size == params.ngram_size;
382*993b0882SAndroid Build Coastguard Worker   }
383*993b0882SAndroid Build Coastguard Worker }
384*993b0882SAndroid Build Coastguard Worker 
ShouldStepInRecursion(const std::vector<int> & stack,int stack_idx,int num_words,const SkipGramParams & params)385*993b0882SAndroid Build Coastguard Worker bool ShouldStepInRecursion(const std::vector<int>& stack, int stack_idx,
386*993b0882SAndroid Build Coastguard Worker                            int num_words, const SkipGramParams& params) {
387*993b0882SAndroid Build Coastguard Worker   // If current stack size and next word enumeration are within valid range.
388*993b0882SAndroid Build Coastguard Worker   if (stack_idx < params.ngram_size && stack[stack_idx] + 1 < num_words) {
389*993b0882SAndroid Build Coastguard Worker     // If this stack is empty, step in for first word enumeration.
390*993b0882SAndroid Build Coastguard Worker     if (stack_idx == 0) {
391*993b0882SAndroid Build Coastguard Worker       return true;
392*993b0882SAndroid Build Coastguard Worker     }
393*993b0882SAndroid Build Coastguard Worker     // If next word enumeration are within the range of max_skip_size.
394*993b0882SAndroid Build Coastguard Worker     // NOTE: equivalent to
395*993b0882SAndroid Build Coastguard Worker     //   next_word_idx = stack[stack_idx] + 1
396*993b0882SAndroid Build Coastguard Worker     //   next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
397*993b0882SAndroid Build Coastguard Worker     if (stack[stack_idx] - stack[stack_idx - 1] <= params.max_skip_size) {
398*993b0882SAndroid Build Coastguard Worker       return true;
399*993b0882SAndroid Build Coastguard Worker     }
400*993b0882SAndroid Build Coastguard Worker   }
401*993b0882SAndroid Build Coastguard Worker   return false;
402*993b0882SAndroid Build Coastguard Worker }
403*993b0882SAndroid Build Coastguard Worker 
// Joins tokens[stack[0]], tokens[stack[1]], ..., tokens[stack[stack_idx - 1]]
// with a single space between consecutive tokens.
//
// Returns an empty string when stack_idx <= 0 (no tokens selected).
std::string JoinTokensBySpace(const std::vector<int>& stack, int stack_idx,
                              const std::vector<std::string>& tokens) {
  // Guard the empty selection. Without it, `stack_idx - 1` would be negative
  // when sizing the buffer (reserve() would throw std::length_error), and
  // tokens[stack[0]] would be appended even though nothing was selected.
  if (stack_idx <= 0) {
    return std::string();
  }

  // Pre-size the buffer: all token lengths plus one separator between each
  // adjacent pair. size_t avoids int overflow for very long tokens.
  size_t len = static_cast<size_t>(stack_idx - 1);
  for (int i = 0; i < stack_idx; ++i) {
    len += tokens[stack[i]].size();
  }

  std::string res;
  res.reserve(len);
  res.append(tokens[stack[0]]);
  for (int i = 1; i < stack_idx; ++i) {
    res.push_back(' ');
    res.append(tokens[stack[i]]);
  }
  return res;
}
422*993b0882SAndroid Build Coastguard Worker 
ExtractSkipGramsImpl(const std::vector<std::string> & tokens,const SkipGramParams & params)423*993b0882SAndroid Build Coastguard Worker std::unordered_map<std::string, int> ExtractSkipGramsImpl(
424*993b0882SAndroid Build Coastguard Worker     const std::vector<std::string>& tokens, const SkipGramParams& params) {
425*993b0882SAndroid Build Coastguard Worker   // Ignore positional tokens.
426*993b0882SAndroid Build Coastguard Worker   static auto* blacklist = new std::unordered_set<std::string>({
427*993b0882SAndroid Build Coastguard Worker       kStartToken,
428*993b0882SAndroid Build Coastguard Worker       kEndToken,
429*993b0882SAndroid Build Coastguard Worker       kEmptyToken,
430*993b0882SAndroid Build Coastguard Worker   });
431*993b0882SAndroid Build Coastguard Worker 
432*993b0882SAndroid Build Coastguard Worker   std::unordered_map<std::string, int> res;
433*993b0882SAndroid Build Coastguard Worker 
434*993b0882SAndroid Build Coastguard Worker   // Stack stores the index of word used to generate ngram.
435*993b0882SAndroid Build Coastguard Worker   // The size of stack is the size of ngram.
436*993b0882SAndroid Build Coastguard Worker   std::vector<int> stack(params.ngram_size + 1, 0);
437*993b0882SAndroid Build Coastguard Worker   // Stack index that indicates which depth the recursion is operating at.
438*993b0882SAndroid Build Coastguard Worker   int stack_idx = 1;
439*993b0882SAndroid Build Coastguard Worker   int num_words = tokens.size();
440*993b0882SAndroid Build Coastguard Worker 
441*993b0882SAndroid Build Coastguard Worker   while (stack_idx >= 0) {
442*993b0882SAndroid Build Coastguard Worker     if (ShouldStepInRecursion(stack, stack_idx, num_words, params)) {
443*993b0882SAndroid Build Coastguard Worker       // When current depth can fill with a new word
444*993b0882SAndroid Build Coastguard Worker       // and the new word is within the max range to skip,
445*993b0882SAndroid Build Coastguard Worker       // fill this word to stack, recurse into next depth.
446*993b0882SAndroid Build Coastguard Worker       stack[stack_idx]++;
447*993b0882SAndroid Build Coastguard Worker       stack_idx++;
448*993b0882SAndroid Build Coastguard Worker       stack[stack_idx] = stack[stack_idx - 1];
449*993b0882SAndroid Build Coastguard Worker     } else {
450*993b0882SAndroid Build Coastguard Worker       if (ShouldIncludeCurrentNgram(params, stack_idx)) {
451*993b0882SAndroid Build Coastguard Worker         // Add n-gram to tensor buffer when the stack has filled with enough
452*993b0882SAndroid Build Coastguard Worker         // words to generate the ngram.
453*993b0882SAndroid Build Coastguard Worker         std::string ngram = JoinTokensBySpace(stack, stack_idx, tokens);
454*993b0882SAndroid Build Coastguard Worker         if (blacklist->find(ngram) == blacklist->end()) {
455*993b0882SAndroid Build Coastguard Worker           res[ngram] = stack_idx;
456*993b0882SAndroid Build Coastguard Worker         }
457*993b0882SAndroid Build Coastguard Worker       }
458*993b0882SAndroid Build Coastguard Worker       // When current depth cannot fill with a valid new word,
459*993b0882SAndroid Build Coastguard Worker       // and not in last depth to generate ngram,
460*993b0882SAndroid Build Coastguard Worker       // step back to previous depth to iterate to next possible word.
461*993b0882SAndroid Build Coastguard Worker       stack_idx--;
462*993b0882SAndroid Build Coastguard Worker     }
463*993b0882SAndroid Build Coastguard Worker   }
464*993b0882SAndroid Build Coastguard Worker 
465*993b0882SAndroid Build Coastguard Worker   return res;
466*993b0882SAndroid Build Coastguard Worker }
467*993b0882SAndroid Build Coastguard Worker 
ExtractSkipGrams(const std::string & input,ProjectionTokenizer * tokenizer,ProjectionNormalizer * normalizer,const SkipGramParams & params)468*993b0882SAndroid Build Coastguard Worker std::unordered_map<std::string, int> ExtractSkipGrams(
469*993b0882SAndroid Build Coastguard Worker     const std::string& input, ProjectionTokenizer* tokenizer,
470*993b0882SAndroid Build Coastguard Worker     ProjectionNormalizer* normalizer, const SkipGramParams& params) {
471*993b0882SAndroid Build Coastguard Worker   // Normalize the input.
472*993b0882SAndroid Build Coastguard Worker   const std::string& normalized =
473*993b0882SAndroid Build Coastguard Worker       normalizer == nullptr
474*993b0882SAndroid Build Coastguard Worker           ? input
475*993b0882SAndroid Build Coastguard Worker           : normalizer->Normalize(input, params.max_input_chars);
476*993b0882SAndroid Build Coastguard Worker 
477*993b0882SAndroid Build Coastguard Worker   // Split sentence to words.
478*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> tokens;
479*993b0882SAndroid Build Coastguard Worker   if (params.char_level) {
480*993b0882SAndroid Build Coastguard Worker     tokens = SplitByChar(normalized.data(), normalized.size(),
481*993b0882SAndroid Build Coastguard Worker                          params.max_input_chars);
482*993b0882SAndroid Build Coastguard Worker   } else {
483*993b0882SAndroid Build Coastguard Worker     tokens = tokenizer->Tokenize(normalized.data(), normalized.size(),
484*993b0882SAndroid Build Coastguard Worker                                  params.max_input_chars, kAllTokens);
485*993b0882SAndroid Build Coastguard Worker   }
486*993b0882SAndroid Build Coastguard Worker 
487*993b0882SAndroid Build Coastguard Worker   // Process tokens
488*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < tokens.size(); ++i) {
489*993b0882SAndroid Build Coastguard Worker     if (params.preprocess) {
490*993b0882SAndroid Build Coastguard Worker       tokens[i] = PreProcessString(tokens[i].data(), tokens[i].size(),
491*993b0882SAndroid Build Coastguard Worker                                    params.remove_punctuation);
492*993b0882SAndroid Build Coastguard Worker     }
493*993b0882SAndroid Build Coastguard Worker   }
494*993b0882SAndroid Build Coastguard Worker 
495*993b0882SAndroid Build Coastguard Worker   tokens.insert(tokens.begin(), kStartToken);
496*993b0882SAndroid Build Coastguard Worker   tokens.insert(tokens.end(), kEndToken);
497*993b0882SAndroid Build Coastguard Worker 
498*993b0882SAndroid Build Coastguard Worker   return ExtractSkipGramsImpl(tokens, params);
499*993b0882SAndroid Build Coastguard Worker }
500*993b0882SAndroid Build Coastguard Worker }  // namespace
501*993b0882SAndroid Build Coastguard Worker // Generates LSH projections for input strings.  This uses the framework in
502*993b0882SAndroid Build Coastguard Worker // `string_projection_base.h`, with the implementation details that the input is
503*993b0882SAndroid Build Coastguard Worker // a string tensor of messages and the op will perform tokenization.
504*993b0882SAndroid Build Coastguard Worker //
505*993b0882SAndroid Build Coastguard Worker // Input:
506*993b0882SAndroid Build Coastguard Worker //   tensor[0]: Input message, string[...]
507*993b0882SAndroid Build Coastguard Worker //
508*993b0882SAndroid Build Coastguard Worker // Additional attributes:
509*993b0882SAndroid Build Coastguard Worker //   max_input_chars: int[]
510*993b0882SAndroid Build Coastguard Worker //     maximum number of input characters to use from each message.
//   token_separators: string[]
//     separator characters passed to the normalizer; the tokenizer itself is
//     constructed with a single space (" ") as its separator.
513*993b0882SAndroid Build Coastguard Worker //   normalize_repetition: bool[]
514*993b0882SAndroid Build Coastguard Worker //     if true, remove repeated characters in tokens ('loool' -> 'lol').
515*993b0882SAndroid Build Coastguard Worker 
// Position of the input message tensor within node->inputs.
constexpr int kInputMessage = 0;
517*993b0882SAndroid Build Coastguard Worker 
518*993b0882SAndroid Build Coastguard Worker class StringProjectionOp : public StringProjectionOpBase {
519*993b0882SAndroid Build Coastguard Worker  public:
StringProjectionOp(const flexbuffers::Map & custom_options)520*993b0882SAndroid Build Coastguard Worker   explicit StringProjectionOp(const flexbuffers::Map& custom_options)
521*993b0882SAndroid Build Coastguard Worker       : StringProjectionOpBase(custom_options),
522*993b0882SAndroid Build Coastguard Worker         projection_normalizer_(
523*993b0882SAndroid Build Coastguard Worker             custom_options["token_separators"].AsString().str(),
524*993b0882SAndroid Build Coastguard Worker             custom_options["normalize_repetition"].AsBool()),
525*993b0882SAndroid Build Coastguard Worker         projection_tokenizer_(" ") {
526*993b0882SAndroid Build Coastguard Worker     if (custom_options["max_input_chars"].IsInt()) {
527*993b0882SAndroid Build Coastguard Worker       skip_gram_params().max_input_chars =
528*993b0882SAndroid Build Coastguard Worker           custom_options["max_input_chars"].AsInt32();
529*993b0882SAndroid Build Coastguard Worker     }
530*993b0882SAndroid Build Coastguard Worker   }
531*993b0882SAndroid Build Coastguard Worker 
InitializeInput(TfLiteContext * context,TfLiteNode * node)532*993b0882SAndroid Build Coastguard Worker   TfLiteStatus InitializeInput(TfLiteContext* context,
533*993b0882SAndroid Build Coastguard Worker                                TfLiteNode* node) override {
534*993b0882SAndroid Build Coastguard Worker     input_ = &context->tensors[node->inputs->data[kInputMessage]];
535*993b0882SAndroid Build Coastguard Worker     return kTfLiteOk;
536*993b0882SAndroid Build Coastguard Worker   }
537*993b0882SAndroid Build Coastguard Worker 
ExtractSkipGrams(int i)538*993b0882SAndroid Build Coastguard Worker   std::unordered_map<std::string, int> ExtractSkipGrams(int i) override {
539*993b0882SAndroid Build Coastguard Worker     StringRef input = GetString(input_, i);
540*993b0882SAndroid Build Coastguard Worker     return ::tflite::ops::custom::libtextclassifier3::string_projection::
541*993b0882SAndroid Build Coastguard Worker         ExtractSkipGrams({input.str, static_cast<size_t>(input.len)},
542*993b0882SAndroid Build Coastguard Worker                          &projection_tokenizer_, &projection_normalizer_,
543*993b0882SAndroid Build Coastguard Worker                          skip_gram_params());
544*993b0882SAndroid Build Coastguard Worker   }
545*993b0882SAndroid Build Coastguard Worker 
FinalizeInput()546*993b0882SAndroid Build Coastguard Worker   void FinalizeInput() override { input_ = nullptr; }
547*993b0882SAndroid Build Coastguard Worker 
GetInputShape(TfLiteContext * context,TfLiteNode * node)548*993b0882SAndroid Build Coastguard Worker   TfLiteIntArray* GetInputShape(TfLiteContext* context,
549*993b0882SAndroid Build Coastguard Worker                                 TfLiteNode* node) override {
550*993b0882SAndroid Build Coastguard Worker     return context->tensors[node->inputs->data[kInputMessage]].dims;
551*993b0882SAndroid Build Coastguard Worker   }
552*993b0882SAndroid Build Coastguard Worker 
553*993b0882SAndroid Build Coastguard Worker  private:
554*993b0882SAndroid Build Coastguard Worker   ProjectionNormalizer projection_normalizer_;
555*993b0882SAndroid Build Coastguard Worker   ProjectionTokenizer projection_tokenizer_;
556*993b0882SAndroid Build Coastguard Worker 
557*993b0882SAndroid Build Coastguard Worker   TfLiteTensor* input_;
558*993b0882SAndroid Build Coastguard Worker };
559*993b0882SAndroid Build Coastguard Worker 
Init(TfLiteContext * context,const char * buffer,size_t length)560*993b0882SAndroid Build Coastguard Worker void* Init(TfLiteContext* context, const char* buffer, size_t length) {
561*993b0882SAndroid Build Coastguard Worker   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
562*993b0882SAndroid Build Coastguard Worker   return new StringProjectionOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
563*993b0882SAndroid Build Coastguard Worker }
564*993b0882SAndroid Build Coastguard Worker 
565*993b0882SAndroid Build Coastguard Worker }  // namespace string_projection
566*993b0882SAndroid Build Coastguard Worker 
567*993b0882SAndroid Build Coastguard Worker // This op converts a list of strings to integers via LSH projections.
Register_STRING_PROJECTION()568*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_STRING_PROJECTION() {
569*993b0882SAndroid Build Coastguard Worker   static TfLiteRegistration r = {libtextclassifier3::string_projection::Init,
570*993b0882SAndroid Build Coastguard Worker                                  libtextclassifier3::string_projection::Free,
571*993b0882SAndroid Build Coastguard Worker                                  libtextclassifier3::string_projection::Resize,
572*993b0882SAndroid Build Coastguard Worker                                  libtextclassifier3::string_projection::Eval};
573*993b0882SAndroid Build Coastguard Worker   return &r;
574*993b0882SAndroid Build Coastguard Worker }
575*993b0882SAndroid Build Coastguard Worker 
576*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
577*993b0882SAndroid Build Coastguard Worker }  // namespace custom
578*993b0882SAndroid Build Coastguard Worker }  // namespace ops
579*993b0882SAndroid Build Coastguard Worker }  // namespace tflite
580