xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/string_projection.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "utils/tflite/string_projection.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "utils/strings/utf8.h"
#include "utils/tflite/string_projection_base.h"
#include "utils/utf8/unilib-common.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/string_util.h"
30 
31 namespace tflite {
32 namespace ops {
33 namespace custom {
34 
35 namespace libtextclassifier3 {
36 namespace string_projection {
37 namespace {
38 
// Sentinel tokens inserted around every token sequence before skip-gram
// extraction; kEmptyToken is the ngram an empty message would produce.
const char kStartToken[] = "<S>";
const char kEndToken[] = "<E>";
const char kEmptyToken[] = "<S> <E>";
// SIZE_MAX sentinels meaning "no limit" on input length / token count.
constexpr size_t kEntireString = SIZE_MAX;
constexpr size_t kAllTokens = SIZE_MAX;
// Not-found marker. NOTE: it is frequently compared against size_t values,
// where -1 converts to SIZE_MAX.
constexpr int kInvalid = -1;

// Characters with special tokenization behavior.
constexpr char kApostrophe = '\'';
constexpr char kSpace = ' ';
constexpr char kComma = ',';
constexpr char kDot = '.';
50 
51 // Returns true if the given text contains a number.
IsDigitString(const std::string & text)52 bool IsDigitString(const std::string& text) {
53   for (size_t i = 0; i < text.length();) {
54     const int bytes_read =
55         ::libtextclassifier3::GetNumBytesForUTF8Char(text.data());
56     if (bytes_read <= 0 || bytes_read > text.length() - i) {
57       break;
58     }
59     const char32_t rune = ::libtextclassifier3::ValidCharToRune(text.data());
60     if (::libtextclassifier3::IsDigit(rune)) return true;
61     i += bytes_read;
62   }
63   return false;
64 }
65 
66 // Gets the string containing |num_chars| characters from |start| position.
std::string GetCharToken(const std::vector<std::string>& char_tokens, int start,
                         int num_chars) {
  // Requests reaching past the end of |char_tokens| yield an empty token.
  if (start + num_chars > char_tokens.size()) {
    return "";
  }
  std::string result;
  for (int offset = start; offset < start + num_chars; ++offset) {
    result += char_tokens[offset];
  }
  return result;
}
77 
78 // Counts how many times |pattern| appeared from |start| position.
GetNumPattern(const std::vector<std::string> & char_tokens,size_t start,size_t num_chars,const std::string & pattern)79 int GetNumPattern(const std::vector<std::string>& char_tokens, size_t start,
80                   size_t num_chars, const std::string& pattern) {
81   int count = 0;
82   for (int i = start; i < char_tokens.size(); i += num_chars) {
83     std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
84     if (pattern == cur_pattern) {
85       ++count;
86     } else {
87       break;
88     }
89   }
90   return count;
91 }
92 
FindNextSpace(const char * input_ptr,size_t from,size_t length)93 inline size_t FindNextSpace(const char* input_ptr, size_t from, size_t length) {
94   size_t space_index;
95   for (space_index = from; space_index < length; space_index++) {
96     if (input_ptr[space_index] == kSpace) {
97       break;
98     }
99   }
100   return space_index == length ? kInvalid : space_index;
101 }
102 
103 template <typename T>
SplitByCharInternal(std::vector<T> * tokens,const char * input_ptr,size_t len,size_t max_tokens)104 void SplitByCharInternal(std::vector<T>* tokens, const char* input_ptr,
105                          size_t len, size_t max_tokens) {
106   for (size_t i = 0; i < len;) {
107     auto bytes_read =
108         ::libtextclassifier3::GetNumBytesForUTF8Char(input_ptr + i);
109     if (bytes_read <= 0 || bytes_read > len - i) break;
110     tokens->emplace_back(input_ptr + i, bytes_read);
111     if (max_tokens != kInvalid && tokens->size() == max_tokens) {
112       break;
113     }
114     i += bytes_read;
115   }
116 }
117 
SplitByChar(const char * input_ptr,size_t len,size_t max_tokens)118 std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
119                                      size_t max_tokens) {
120   std::vector<std::string> tokens;
121   SplitByCharInternal(&tokens, input_ptr, len, max_tokens);
122   return tokens;
123 }
124 
ContractToken(const char * input_ptr,size_t len,size_t num_chars)125 std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
126   // This function contracts patterns whose length is |num_chars| and appeared
127   // more than twice. So if the input is shorter than 3 * |num_chars|, do not
128   // apply any contraction.
129   if (len < 3 * num_chars) {
130     return input_ptr;
131   }
132   std::vector<std::string> char_tokens = SplitByChar(input_ptr, len, len);
133 
134   std::string token;
135   token.reserve(len);
136   for (int i = 0; i < char_tokens.size();) {
137     std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
138 
139     // Count how many times this pattern appeared.
140     int num_cur_patterns = 0;
141     if (!absl::StrContains(cur_pattern, " ") && !IsDigitString(cur_pattern)) {
142       num_cur_patterns =
143           GetNumPattern(char_tokens, i + num_chars, num_chars, cur_pattern);
144     }
145 
146     if (num_cur_patterns >= 2) {
147       // If this pattern is repeated, store it only twice.
148       token.append(cur_pattern);
149       token.append(cur_pattern);
150       i += (num_cur_patterns + 1) * num_chars;
151     } else {
152       token.append(char_tokens[i]);
153       ++i;
154     }
155   }
156 
157   return token;
158 }
159 
// Splits input on single spaces into |tokens|, considering at most
// |max_input| bytes (kEntireString = all of |len|) and emitting at most
// |max_tokens| tokens (kAllTokens = no limit). Runs of spaces produce no
// empty tokens. NOTE: FindNextSpace returns the int constant kInvalid (-1);
// comparing it with the size_t |end| relies on -1 converting to SIZE_MAX.
template <typename T>
void SplitBySpaceInternal(std::vector<T>* tokens, const char* input_ptr,
                          size_t len, size_t max_input, size_t max_tokens) {
  size_t last_index =
      max_input == kEntireString ? len : (len < max_input ? len : max_input);
  size_t start = 0;
  // skip leading spaces
  while (start < last_index && input_ptr[start] == kSpace) {
    start++;
  }
  auto end = FindNextSpace(input_ptr, start, last_index);
  // Stop one token early when |max_tokens| is bounded: the final token is
  // appended after the loop.
  while (end != kInvalid &&
         (max_tokens == kAllTokens || tokens->size() < max_tokens - 1)) {
    auto length = end - start;
    if (length > 0) {
      tokens->emplace_back(input_ptr + start, length);
    }

    start = end + 1;
    end = FindNextSpace(input_ptr, start, last_index);
  }
  // Append the trailing token: everything up to the next space, or up to
  // |last_index| when no space remains.
  auto length = end == kInvalid ? (last_index - start) : (end - start);
  if (length > 0) {
    tokens->emplace_back(input_ptr + start, length);
  }
}
186 
SplitBySpace(const char * input_ptr,size_t len,size_t max_input,size_t max_tokens)187 std::vector<std::string> SplitBySpace(const char* input_ptr, size_t len,
188                                       size_t max_input, size_t max_tokens) {
189   std::vector<std::string> tokens;
190   SplitBySpaceInternal(&tokens, input_ptr, len, max_input, max_tokens);
191   return tokens;
192 }
193 
prepend_separator(char separator)194 bool prepend_separator(char separator) { return separator == kApostrophe; }
195 
is_numeric(char c)196 bool is_numeric(char c) { return c >= '0' && c <= '9'; }
197 
198 class ProjectionNormalizer {
199  public:
ProjectionNormalizer(const std::string & separators,bool normalize_repetition=false)200   explicit ProjectionNormalizer(const std::string& separators,
201                                 bool normalize_repetition = false) {
202     InitializeSeparators(separators);
203     normalize_repetition_ = normalize_repetition;
204   }
205 
206   // Normalizes the repeated characters (except numbers) which consecutively
207   // appeared more than twice in a word.
Normalize(const std::string & input,size_t max_input=300)208   std::string Normalize(const std::string& input, size_t max_input = 300) {
209     return Normalize(input.data(), input.length(), max_input);
210   }
Normalize(const char * input_ptr,size_t len,size_t max_input=300)211   std::string Normalize(const char* input_ptr, size_t len,
212                         size_t max_input = 300) {
213     std::string normalized(input_ptr, std::min(len, max_input));
214 
215     if (normalize_repetition_) {
216       // Remove repeated 1 char (e.g. soooo => soo)
217       normalized = ContractToken(normalized.data(), normalized.length(), 1);
218 
219       // Remove repeated 2 chars from the beginning (e.g. hahaha =>
220       // haha, xhahaha => xhaha, xyhahaha => xyhaha).
221       normalized = ContractToken(normalized.data(), normalized.length(), 2);
222 
223       // Remove repeated 3 chars from the beginning
224       // (e.g. wowwowwow => wowwow, abcdbcdbcd => abcdbcd).
225       normalized = ContractToken(normalized.data(), normalized.length(), 3);
226     }
227 
228     if (!separators_.empty()) {
229       // Add space around separators_.
230       normalized = NormalizeInternal(normalized.data(), normalized.length());
231     }
232     return normalized;
233   }
234 
235  private:
236   // Parses and extracts supported separators.
InitializeSeparators(const std::string & separators)237   void InitializeSeparators(const std::string& separators) {
238     for (int i = 0; i < separators.length(); ++i) {
239       if (separators[i] != ' ') {
240         separators_.insert(separators[i]);
241       }
242     }
243   }
244 
245   // Removes repeated chars.
NormalizeInternal(const char * input_ptr,size_t len)246   std::string NormalizeInternal(const char* input_ptr, size_t len) {
247     std::string normalized;
248     normalized.reserve(len * 2);
249     for (int i = 0; i < len; ++i) {
250       char c = input_ptr[i];
251       bool matched_separator = separators_.find(c) != separators_.end();
252       if (matched_separator) {
253         if (i > 0 && input_ptr[i - 1] != ' ' && normalized.back() != ' ') {
254           normalized.append(" ");
255         }
256       }
257       normalized.append(1, c);
258       if (matched_separator) {
259         if (i + 1 < len && input_ptr[i + 1] != ' ' && c != '\'') {
260           normalized.append(" ");
261         }
262       }
263     }
264     return normalized;
265   }
266 
267   absl::flat_hash_set<char> separators_;
268   bool normalize_repetition_;
269 };
270 
// Splits text into tokens on the configured separator characters, falling
// back to space-splitting when no separators were configured. Separator
// characters (other than space and apostrophe) are themselves emitted as
// tokens.
class ProjectionTokenizer {
 public:
  explicit ProjectionTokenizer(const std::string& separators) {
    InitializeSeparators(separators);
  }

  // Tokenizes the input by separators_. Limit to max_tokens, when it is not -1.
  std::vector<std::string> Tokenize(const std::string& input, size_t max_input,
                                    size_t max_tokens) const {
    return Tokenize(input.c_str(), input.size(), max_input, max_tokens);
  }

  // Tokenizes input_ptr[0, len), considering at most |max_input| bytes
  // (kEntireString = all) and emitting at most |max_tokens| tokens
  // (kAllTokens = no limit).
  std::vector<std::string> Tokenize(const char* input_ptr, size_t len,
                                    size_t max_input, size_t max_tokens) const {
    // If separators_ is not given, tokenize the input with a space.
    if (separators_.empty()) {
      return SplitBySpace(input_ptr, len, max_input, max_tokens);
    }

    std::vector<std::string> tokens;
    size_t last_index =
        max_input == kEntireString ? len : (len < max_input ? len : max_input);
    size_t start = 0;
    // Skip leading spaces.
    while (start < last_index && input_ptr[start] == kSpace) {
      start++;
    }
    auto end = FindNextSeparator(input_ptr, start, last_index);

    // NOTE: FindNextSeparator returns the int constant kInvalid (-1);
    // comparing it with the size_t |end| relies on -1 converting to SIZE_MAX.
    // Stop one token early when |max_tokens| is bounded: the final token is
    // appended after the loop.
    while (end != kInvalid &&
           (max_tokens == kAllTokens || tokens.size() < max_tokens - 1)) {
      auto length = end - start;
      if (length > 0) tokens.emplace_back(input_ptr + start, length);

      // Add the separator (except space and apostrophe) as a token
      char separator = input_ptr[end];
      if (separator != kSpace && separator != kApostrophe) {
        tokens.emplace_back(input_ptr + end, 1);
      }

      // An apostrophe stays attached to the following token (e.g. "'s"), so
      // in that case the next token starts at the separator itself.
      start = end + (prepend_separator(separator) ? 0 : 1);
      end = FindNextSeparator(input_ptr, end + 1, last_index);
    }
    // Append the trailing token: everything up to the next separator, or up
    // to |last_index| when no separator remains.
    auto length = end == kInvalid ? (last_index - start) : (end - start);
    if (length > 0) tokens.emplace_back(input_ptr + start, length);
    return tokens;
  }

 private:
  // Parses and extracts supported separators.
  void InitializeSeparators(const std::string& separators) {
    for (int i = 0; i < separators.length(); ++i) {
      separators_.insert(separators[i]);
    }
  }

  // Starting from input_ptr[from], search for the next occurrence of
  // separators_. Don't search beyond input_ptr[length](non-inclusive). Return
  // -1 if not found.
  size_t FindNextSeparator(const char* input_ptr, size_t from,
                           size_t length) const {
    auto index = from;
    while (index < length) {
      char c = input_ptr[index];
      // Do not break a number (e.g. "10,000", "0.23"): a comma or dot that is
      // immediately followed by a digit is skipped over.
      if (c == kComma || c == kDot) {
        if (index + 1 < length && is_numeric(input_ptr[index + 1])) {
          c = input_ptr[++index];
        }
      }
      if (separators_.find(c) != separators_.end()) {
        break;
      }
      ++index;
    }
    return index == length ? kInvalid : index;
  }

  absl::flat_hash_set<char> separators_;
};
351 
// Erases any run of trailing ASCII punctuation from |str| in place.
inline void StripTrailingAsciiPunctuation(std::string* str) {
  // Go through unsigned char: passing a negative char (any byte >= 0x80 on
  // platforms with signed char) to ::ispunct is undefined behavior.
  auto it = std::find_if_not(str->rbegin(), str->rend(), [](unsigned char c) {
    return std::ispunct(c) != 0;
  });
  str->erase(str->rend() - it);
}
356 
PreProcessString(const char * str,int len,const bool remove_punctuation)357 std::string PreProcessString(const char* str, int len,
358                              const bool remove_punctuation) {
359   std::string output_str(str, len);
360   std::transform(output_str.begin(), output_str.end(), output_str.begin(),
361                  ::tolower);
362 
363   // Remove trailing punctuation.
364   if (remove_punctuation) {
365     StripTrailingAsciiPunctuation(&output_str);
366   }
367 
368   if (output_str.empty()) {
369     output_str.assign(str, len);
370   }
371   return output_str;
372 }
373 
ShouldIncludeCurrentNgram(const SkipGramParams & params,int size)374 bool ShouldIncludeCurrentNgram(const SkipGramParams& params, int size) {
375   if (size <= 0) {
376     return false;
377   }
378   if (params.include_all_ngrams) {
379     return size <= params.ngram_size;
380   } else {
381     return size == params.ngram_size;
382   }
383 }
384 
ShouldStepInRecursion(const std::vector<int> & stack,int stack_idx,int num_words,const SkipGramParams & params)385 bool ShouldStepInRecursion(const std::vector<int>& stack, int stack_idx,
386                            int num_words, const SkipGramParams& params) {
387   // If current stack size and next word enumeration are within valid range.
388   if (stack_idx < params.ngram_size && stack[stack_idx] + 1 < num_words) {
389     // If this stack is empty, step in for first word enumeration.
390     if (stack_idx == 0) {
391       return true;
392     }
393     // If next word enumeration are within the range of max_skip_size.
394     // NOTE: equivalent to
395     //   next_word_idx = stack[stack_idx] + 1
396     //   next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
397     if (stack[stack_idx] - stack[stack_idx - 1] <= params.max_skip_size) {
398       return true;
399     }
400   }
401   return false;
402 }
403 
// Joins the tokens selected by stack[0..stack_idx) with single spaces.
// Precondition (guaranteed by callers via ShouldIncludeCurrentNgram):
// stack_idx >= 1.
std::string JoinTokensBySpace(const std::vector<int>& stack, int stack_idx,
                              const std::vector<std::string>& tokens) {
  // Pre-size the result: total token length plus (stack_idx - 1) separators.
  int total_len = stack_idx - 1;
  for (int i = 0; i < stack_idx; i++) {
    total_len += tokens[stack[i]].size();
  }

  std::string joined;
  joined.reserve(total_len);
  for (int i = 0; i < stack_idx; i++) {
    if (i > 0) {
      joined.append(" ");
    }
    joined.append(tokens[stack[i]]);
  }
  return joined;
}
422 
// Enumerates all skip-grams of |tokens| allowed by |params|, emulating the
// recursive enumeration with an explicit index stack. Returns a map from the
// space-joined ngram string to its token count.
std::unordered_map<std::string, int> ExtractSkipGramsImpl(
    const std::vector<std::string>& tokens, const SkipGramParams& params) {
  // Ignore positional tokens.
  static auto* blacklist = new std::unordered_set<std::string>({
      kStartToken,
      kEndToken,
      kEmptyToken,
  });

  std::unordered_map<std::string, int> res;

  // Stack stores the index of word used to generate ngram.
  // The size of stack is the size of ngram.
  std::vector<int> stack(params.ngram_size + 1, 0);
  // Stack index that indicates which depth the recursion is operating at.
  int stack_idx = 1;
  int num_words = tokens.size();

  while (stack_idx >= 0) {
    if (ShouldStepInRecursion(stack, stack_idx, num_words, params)) {
      // When current depth can fill with a new word
      // and the new word is within the max range to skip,
      // fill this word to stack, recurse into next depth.
      stack[stack_idx]++;
      stack_idx++;
      // Start the next depth's enumeration at the word just consumed.
      stack[stack_idx] = stack[stack_idx - 1];
    } else {
      if (ShouldIncludeCurrentNgram(params, stack_idx)) {
        // Add n-gram to tensor buffer when the stack has filled with enough
        // words to generate the ngram.
        std::string ngram = JoinTokensBySpace(stack, stack_idx, tokens);
        if (blacklist->find(ngram) == blacklist->end()) {
          // The mapped value is the ngram's length in tokens; a later
          // occurrence of the same ngram string overwrites the earlier entry.
          res[ngram] = stack_idx;
        }
      }
      // When current depth cannot fill with a valid new word,
      // and not in last depth to generate ngram,
      // step back to previous depth to iterate to next possible word.
      stack_idx--;
    }
  }

  return res;
}
467 
ExtractSkipGrams(const std::string & input,ProjectionTokenizer * tokenizer,ProjectionNormalizer * normalizer,const SkipGramParams & params)468 std::unordered_map<std::string, int> ExtractSkipGrams(
469     const std::string& input, ProjectionTokenizer* tokenizer,
470     ProjectionNormalizer* normalizer, const SkipGramParams& params) {
471   // Normalize the input.
472   const std::string& normalized =
473       normalizer == nullptr
474           ? input
475           : normalizer->Normalize(input, params.max_input_chars);
476 
477   // Split sentence to words.
478   std::vector<std::string> tokens;
479   if (params.char_level) {
480     tokens = SplitByChar(normalized.data(), normalized.size(),
481                          params.max_input_chars);
482   } else {
483     tokens = tokenizer->Tokenize(normalized.data(), normalized.size(),
484                                  params.max_input_chars, kAllTokens);
485   }
486 
487   // Process tokens
488   for (int i = 0; i < tokens.size(); ++i) {
489     if (params.preprocess) {
490       tokens[i] = PreProcessString(tokens[i].data(), tokens[i].size(),
491                                    params.remove_punctuation);
492     }
493   }
494 
495   tokens.insert(tokens.begin(), kStartToken);
496   tokens.insert(tokens.end(), kEndToken);
497 
498   return ExtractSkipGramsImpl(tokens, params);
499 }
500 }  // namespace
501 // Generates LSH projections for input strings.  This uses the framework in
502 // `string_projection_base.h`, with the implementation details that the input is
503 // a string tensor of messages and the op will perform tokenization.
504 //
505 // Input:
506 //   tensor[0]: Input message, string[...]
507 //
508 // Additional attributes:
509 //   max_input_chars: int[]
510 //     maximum number of input characters to use from each message.
511 //   token_separators: string[]
512 //     the list of separators used to tokenize the input.
513 //   normalize_repetition: bool[]
514 //     if true, remove repeated characters in tokens ('loool' -> 'lol').
515 
516 static const int kInputMessage = 0;
517 
// TFLite op implementation: tokenizes/normalizes each input string and
// delegates the LSH projection machinery to StringProjectionOpBase.
class StringProjectionOp : public StringProjectionOpBase {
 public:
  explicit StringProjectionOp(const flexbuffers::Map& custom_options)
      : StringProjectionOpBase(custom_options),
        projection_normalizer_(
            custom_options["token_separators"].AsString().str(),
            custom_options["normalize_repetition"].AsBool()),
        // Tokenizer is constructed with " " as its separator set; the space
        // is filtered out by InitializeSeparators, so it falls back to
        // space-splitting.
        projection_tokenizer_(" ") {
    // Optional attribute: cap on the number of characters consumed per
    // message.
    if (custom_options["max_input_chars"].IsInt()) {
      skip_gram_params().max_input_chars =
          custom_options["max_input_chars"].AsInt32();
    }
  }

  // Caches a pointer to the input string tensor for the duration of Eval.
  TfLiteStatus InitializeInput(TfLiteContext* context,
                               TfLiteNode* node) override {
    input_ = &context->tensors[node->inputs->data[kInputMessage]];
    return kTfLiteOk;
  }

  // Extracts skip-grams from the i-th string of the cached input tensor.
  std::unordered_map<std::string, int> ExtractSkipGrams(int i) override {
    StringRef input = GetString(input_, i);
    return ::tflite::ops::custom::libtextclassifier3::string_projection::
        ExtractSkipGrams({input.str, static_cast<size_t>(input.len)},
                         &projection_tokenizer_, &projection_normalizer_,
                         skip_gram_params());
  }

  // Drops the cached tensor pointer once Eval is done.
  void FinalizeInput() override { input_ = nullptr; }

  TfLiteIntArray* GetInputShape(TfLiteContext* context,
                                TfLiteNode* node) override {
    return context->tensors[node->inputs->data[kInputMessage]].dims;
  }

 private:
  ProjectionNormalizer projection_normalizer_;
  ProjectionTokenizer projection_tokenizer_;

  // Non-owning; only valid between InitializeInput() and FinalizeInput().
  // NOTE(review): left uninitialized by the constructor — presumably
  // InitializeInput always runs before any use; confirm against the base
  // class's call sequence.
  TfLiteTensor* input_;
};
559 
Init(TfLiteContext * context,const char * buffer,size_t length)560 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
561   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
562   return new StringProjectionOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
563 }
564 
565 }  // namespace string_projection
566 
567 // This op converts a list of strings to integers via LSH projections.
// Returns the registration record for the STRING_PROJECTION custom op.
// Init is defined above; Free, Resize and Eval come from the shared base
// implementation declared in string_projection_base.h.
TfLiteRegistration* Register_STRING_PROJECTION() {
  static TfLiteRegistration r = {libtextclassifier3::string_projection::Init,
                                 libtextclassifier3::string_projection::Free,
                                 libtextclassifier3::string_projection::Resize,
                                 libtextclassifier3::string_projection::Eval};
  return &r;
}
575 
576 }  // namespace libtextclassifier3
577 }  // namespace custom
578 }  // namespace ops
579 }  // namespace tflite
580