1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/string_projection.h"
18
19 #include <string>
20 #include <unordered_map>
21
22 #include "utils/strings/utf8.h"
23 #include "utils/tflite/string_projection_base.h"
24 #include "utils/utf8/unilib-common.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/match.h"
27 #include "flatbuffers/flexbuffers.h"
28 #include "tensorflow/lite/context.h"
29 #include "tensorflow/lite/string_util.h"
30
31 namespace tflite {
32 namespace ops {
33 namespace custom {
34
35 namespace libtextclassifier3 {
36 namespace string_projection {
37 namespace {
38
// Boundary tokens added around every tokenized input; an empty input is
// represented by kEmptyToken ("<S> <E>").
const char kStartToken[] = "<S>";
const char kEndToken[] = "<E>";
const char kEmptyToken[] = "<S> <E>";
// Sentinels meaning "no limit" for the max_input / max_tokens parameters.
constexpr size_t kEntireString = SIZE_MAX;
constexpr size_t kAllTokens = SIZE_MAX;
// "Not found" sentinel returned by the Find* helpers. NOTE: when compared
// against a size_t, -1 converts to SIZE_MAX — callers rely on this.
constexpr int kInvalid = -1;

// Characters with special meaning during tokenization/normalization.
constexpr char kApostrophe = '\'';
constexpr char kSpace = ' ';
constexpr char kComma = ',';
constexpr char kDot = '.';
50
51 // Returns true if the given text contains a number.
IsDigitString(const std::string & text)52 bool IsDigitString(const std::string& text) {
53 for (size_t i = 0; i < text.length();) {
54 const int bytes_read =
55 ::libtextclassifier3::GetNumBytesForUTF8Char(text.data());
56 if (bytes_read <= 0 || bytes_read > text.length() - i) {
57 break;
58 }
59 const char32_t rune = ::libtextclassifier3::ValidCharToRune(text.data());
60 if (::libtextclassifier3::IsDigit(rune)) return true;
61 i += bytes_read;
62 }
63 return false;
64 }
65
66 // Gets the string containing |num_chars| characters from |start| position.
std::string GetCharToken(const std::vector<std::string>& char_tokens, int start,
                         int num_chars) {
  std::string result;
  // Out-of-range request yields an empty string.
  if (start + num_chars > char_tokens.size()) {
    return result;
  }
  for (int idx = start; idx < start + num_chars; ++idx) {
    result += char_tokens[idx];
  }
  return result;
}
77
78 // Counts how many times |pattern| appeared from |start| position.
GetNumPattern(const std::vector<std::string> & char_tokens,size_t start,size_t num_chars,const std::string & pattern)79 int GetNumPattern(const std::vector<std::string>& char_tokens, size_t start,
80 size_t num_chars, const std::string& pattern) {
81 int count = 0;
82 for (int i = start; i < char_tokens.size(); i += num_chars) {
83 std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
84 if (pattern == cur_pattern) {
85 ++count;
86 } else {
87 break;
88 }
89 }
90 return count;
91 }
92
FindNextSpace(const char * input_ptr,size_t from,size_t length)93 inline size_t FindNextSpace(const char* input_ptr, size_t from, size_t length) {
94 size_t space_index;
95 for (space_index = from; space_index < length; space_index++) {
96 if (input_ptr[space_index] == kSpace) {
97 break;
98 }
99 }
100 return space_index == length ? kInvalid : space_index;
101 }
102
103 template <typename T>
SplitByCharInternal(std::vector<T> * tokens,const char * input_ptr,size_t len,size_t max_tokens)104 void SplitByCharInternal(std::vector<T>* tokens, const char* input_ptr,
105 size_t len, size_t max_tokens) {
106 for (size_t i = 0; i < len;) {
107 auto bytes_read =
108 ::libtextclassifier3::GetNumBytesForUTF8Char(input_ptr + i);
109 if (bytes_read <= 0 || bytes_read > len - i) break;
110 tokens->emplace_back(input_ptr + i, bytes_read);
111 if (max_tokens != kInvalid && tokens->size() == max_tokens) {
112 break;
113 }
114 i += bytes_read;
115 }
116 }
117
SplitByChar(const char * input_ptr,size_t len,size_t max_tokens)118 std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
119 size_t max_tokens) {
120 std::vector<std::string> tokens;
121 SplitByCharInternal(&tokens, input_ptr, len, max_tokens);
122 return tokens;
123 }
124
ContractToken(const char * input_ptr,size_t len,size_t num_chars)125 std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
126 // This function contracts patterns whose length is |num_chars| and appeared
127 // more than twice. So if the input is shorter than 3 * |num_chars|, do not
128 // apply any contraction.
129 if (len < 3 * num_chars) {
130 return input_ptr;
131 }
132 std::vector<std::string> char_tokens = SplitByChar(input_ptr, len, len);
133
134 std::string token;
135 token.reserve(len);
136 for (int i = 0; i < char_tokens.size();) {
137 std::string cur_pattern = GetCharToken(char_tokens, i, num_chars);
138
139 // Count how many times this pattern appeared.
140 int num_cur_patterns = 0;
141 if (!absl::StrContains(cur_pattern, " ") && !IsDigitString(cur_pattern)) {
142 num_cur_patterns =
143 GetNumPattern(char_tokens, i + num_chars, num_chars, cur_pattern);
144 }
145
146 if (num_cur_patterns >= 2) {
147 // If this pattern is repeated, store it only twice.
148 token.append(cur_pattern);
149 token.append(cur_pattern);
150 i += (num_cur_patterns + 1) * num_chars;
151 } else {
152 token.append(char_tokens[i]);
153 ++i;
154 }
155 }
156
157 return token;
158 }
159
// Splits input_ptr[0, min(len, max_input)) on spaces, appending at most
// |max_tokens| tokens (kAllTokens means unlimited) to |tokens|. Runs of
// spaces produce no empty tokens.
template <typename T>
void SplitBySpaceInternal(std::vector<T>* tokens, const char* input_ptr,
                          size_t len, size_t max_input, size_t max_tokens) {
  size_t last_index =
      max_input == kEntireString ? len : (len < max_input ? len : max_input);
  size_t start = 0;
  // skip leading spaces
  while (start < last_index && input_ptr[start] == kSpace) {
    start++;
  }
  // NOTE: |end| is size_t; kInvalid (-1) converts to SIZE_MAX in these
  // comparisons, so "== kInvalid" means "no further space found".
  auto end = FindNextSpace(input_ptr, start, last_index);
  while (end != kInvalid &&
         (max_tokens == kAllTokens || tokens->size() < max_tokens - 1)) {
    auto length = end - start;
    if (length > 0) {
      tokens->emplace_back(input_ptr + start, length);
    }

    start = end + 1;
    end = FindNextSpace(input_ptr, start, last_index);
  }
  // Emit one final token: the text after the last space when none remain, or
  // — when the token cap stopped the loop — the token ending at the next
  // space (any input past that space is dropped).
  auto length = end == kInvalid ? (last_index - start) : (end - start);
  if (length > 0) {
    tokens->emplace_back(input_ptr + start, length);
  }
}
186
SplitBySpace(const char * input_ptr,size_t len,size_t max_input,size_t max_tokens)187 std::vector<std::string> SplitBySpace(const char* input_ptr, size_t len,
188 size_t max_input, size_t max_tokens) {
189 std::vector<std::string> tokens;
190 SplitBySpaceInternal(&tokens, input_ptr, len, max_input, max_tokens);
191 return tokens;
192 }
193
// An apostrophe separator stays attached to the front of the next token.
bool prepend_separator(char separator) { return separator == '\''; }
195
// True for the ASCII decimal digits '0'..'9'.
bool is_numeric(char c) { return '0' <= c && c <= '9'; }
197
// Normalizes raw input text before tokenization: optionally contracts
// character repetitions and inserts spaces around separator characters.
class ProjectionNormalizer {
 public:
  explicit ProjectionNormalizer(const std::string& separators,
                                bool normalize_repetition = false) {
    InitializeSeparators(separators);
    normalize_repetition_ = normalize_repetition;
  }

  // Normalizes the repeated characters (except numbers) which consecutively
  // appeared more than twice in a word.
  std::string Normalize(const std::string& input, size_t max_input = 300) {
    return Normalize(input.data(), input.length(), max_input);
  }
  // Same as above for a raw (pointer, length) buffer; only the first
  // |max_input| bytes of the input are considered.
  std::string Normalize(const char* input_ptr, size_t len,
                        size_t max_input = 300) {
    std::string normalized(input_ptr, std::min(len, max_input));

    if (normalize_repetition_) {
      // Remove repeated 1 char (e.g. soooo => soo)
      normalized = ContractToken(normalized.data(), normalized.length(), 1);

      // Remove repeated 2 chars from the beginning (e.g. hahaha =>
      // haha, xhahaha => xhaha, xyhahaha => xyhaha).
      normalized = ContractToken(normalized.data(), normalized.length(), 2);

      // Remove repeated 3 chars from the beginning
      // (e.g. wowwowwow => wowwow, abcdbcdbcd => abcdbcd).
      normalized = ContractToken(normalized.data(), normalized.length(), 3);
    }

    if (!separators_.empty()) {
      // Add space around separators_.
      normalized = NormalizeInternal(normalized.data(), normalized.length());
    }
    return normalized;
  }

 private:
  // Parses and extracts supported separators; spaces are ignored.
  void InitializeSeparators(const std::string& separators) {
    for (int i = 0; i < separators.length(); ++i) {
      if (separators[i] != ' ') {
        separators_.insert(separators[i]);
      }
    }
  }

  // Inserts spaces around each separator character (the original comment
  // said "removes repeated chars", but that is done by ContractToken above).
  std::string NormalizeInternal(const char* input_ptr, size_t len) {
    std::string normalized;
    normalized.reserve(len * 2);
    for (int i = 0; i < len; ++i) {
      char c = input_ptr[i];
      bool matched_separator = separators_.find(c) != separators_.end();
      if (matched_separator) {
        // Insert a space before the separator unless one is already there.
        // (normalized is non-empty whenever i > 0, so back() is safe.)
        if (i > 0 && input_ptr[i - 1] != ' ' && normalized.back() != ' ') {
          normalized.append(" ");
        }
      }
      normalized.append(1, c);
      if (matched_separator) {
        // Insert a space after the separator; apostrophes stay attached to
        // the following characters.
        if (i + 1 < len && input_ptr[i + 1] != ' ' && c != '\'') {
          normalized.append(" ");
        }
      }
    }
    return normalized;
  }

  // Separator characters that get spaces inserted around them.
  absl::flat_hash_set<char> separators_;
  // Whether Normalize() applies ContractToken repetition contraction.
  bool normalize_repetition_;
};
270
// Splits input strings into tokens using a configurable separator set, or on
// whitespace when the set is empty.
class ProjectionTokenizer {
 public:
  explicit ProjectionTokenizer(const std::string& separators) {
    InitializeSeparators(separators);
  }

  // Tokenizes the input by separators_. Limit to max_tokens, when it is not -1.
  std::vector<std::string> Tokenize(const std::string& input, size_t max_input,
                                    size_t max_tokens) const {
    return Tokenize(input.c_str(), input.size(), max_input, max_tokens);
  }

  // Same as above for a raw (pointer, length) buffer; only the first
  // |max_input| bytes are examined (kEntireString means all of them).
  std::vector<std::string> Tokenize(const char* input_ptr, size_t len,
                                    size_t max_input, size_t max_tokens) const {
    // If separators_ is not given, tokenize the input with a space.
    if (separators_.empty()) {
      return SplitBySpace(input_ptr, len, max_input, max_tokens);
    }

    std::vector<std::string> tokens;
    size_t last_index =
        max_input == kEntireString ? len : (len < max_input ? len : max_input);
    size_t start = 0;
    // Skip leading spaces.
    while (start < last_index && input_ptr[start] == kSpace) {
      start++;
    }
    auto end = FindNextSeparator(input_ptr, start, last_index);

    // NOTE: |end| is size_t; kInvalid (-1) converts to SIZE_MAX in this
    // comparison, so "!= kInvalid" means "a separator was found".
    while (end != kInvalid &&
           (max_tokens == kAllTokens || tokens.size() < max_tokens - 1)) {
      auto length = end - start;
      if (length > 0) tokens.emplace_back(input_ptr + start, length);

      // Add the separator (except space and apostrophe) as a token
      char separator = input_ptr[end];
      if (separator != kSpace && separator != kApostrophe) {
        tokens.emplace_back(input_ptr + end, 1);
      }

      // An apostrophe is kept attached to the following token (start stays
      // on the apostrophe itself).
      start = end + (prepend_separator(separator) ? 0 : 1);
      end = FindNextSeparator(input_ptr, end + 1, last_index);
    }
    // Emit the final token: the remainder when no further separator exists,
    // or the token ending at the next separator when the cap was reached.
    auto length = end == kInvalid ? (last_index - start) : (end - start);
    if (length > 0) tokens.emplace_back(input_ptr + start, length);
    return tokens;
  }

 private:
  // Parses and extracts supported separators.
  void InitializeSeparators(const std::string& separators) {
    for (int i = 0; i < separators.length(); ++i) {
      separators_.insert(separators[i]);
    }
  }

  // Starting from input_ptr[from], search for the next occurrence of
  // separators_. Don't search beyond input_ptr[length](non-inclusive). Return
  // -1 if not found.
  size_t FindNextSeparator(const char* input_ptr, size_t from,
                           size_t length) const {
    auto index = from;
    while (index < length) {
      char c = input_ptr[index];
      // Do not break a number (e.g. "10,000", "0.23"): a comma/dot followed
      // by a digit is skipped over rather than treated as a separator.
      if (c == kComma || c == kDot) {
        if (index + 1 < length && is_numeric(input_ptr[index + 1])) {
          c = input_ptr[++index];
        }
      }
      if (separators_.find(c) != separators_.end()) {
        break;
      }
      ++index;
    }
    return index == length ? kInvalid : index;
  }

  // Characters that terminate a token.
  absl::flat_hash_set<char> separators_;
};
351
// Removes trailing ASCII punctuation characters from |str| in place.
inline void StripTrailingAsciiPunctuation(std::string* str) {
  // Go through unsigned char: passing a negative char (any non-ASCII UTF-8
  // byte on platforms where char is signed) to ::ispunct is undefined
  // behavior.
  auto it = std::find_if_not(str->rbegin(), str->rend(), [](unsigned char c) {
    return ::ispunct(c) != 0;
  });
  // rend() - it is the count of characters to keep from the front.
  str->erase(str->rend() - it);
}
356
PreProcessString(const char * str,int len,const bool remove_punctuation)357 std::string PreProcessString(const char* str, int len,
358 const bool remove_punctuation) {
359 std::string output_str(str, len);
360 std::transform(output_str.begin(), output_str.end(), output_str.begin(),
361 ::tolower);
362
363 // Remove trailing punctuation.
364 if (remove_punctuation) {
365 StripTrailingAsciiPunctuation(&output_str);
366 }
367
368 if (output_str.empty()) {
369 output_str.assign(str, len);
370 }
371 return output_str;
372 }
373
ShouldIncludeCurrentNgram(const SkipGramParams & params,int size)374 bool ShouldIncludeCurrentNgram(const SkipGramParams& params, int size) {
375 if (size <= 0) {
376 return false;
377 }
378 if (params.include_all_ngrams) {
379 return size <= params.ngram_size;
380 } else {
381 return size == params.ngram_size;
382 }
383 }
384
ShouldStepInRecursion(const std::vector<int> & stack,int stack_idx,int num_words,const SkipGramParams & params)385 bool ShouldStepInRecursion(const std::vector<int>& stack, int stack_idx,
386 int num_words, const SkipGramParams& params) {
387 // If current stack size and next word enumeration are within valid range.
388 if (stack_idx < params.ngram_size && stack[stack_idx] + 1 < num_words) {
389 // If this stack is empty, step in for first word enumeration.
390 if (stack_idx == 0) {
391 return true;
392 }
393 // If next word enumeration are within the range of max_skip_size.
394 // NOTE: equivalent to
395 // next_word_idx = stack[stack_idx] + 1
396 // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
397 if (stack[stack_idx] - stack[stack_idx - 1] <= params.max_skip_size) {
398 return true;
399 }
400 }
401 return false;
402 }
403
// Joins the tokens selected by stack[0..stack_idx) with single spaces.
// Assumes stack_idx >= 1 (guaranteed by ShouldIncludeCurrentNgram's caller).
std::string JoinTokensBySpace(const std::vector<int>& stack, int stack_idx,
                              const std::vector<std::string>& tokens) {
  // Pre-size the result: token lengths plus (stack_idx - 1) separator spaces.
  int total = stack_idx - 1;
  for (int i = 0; i < stack_idx; ++i) {
    total += tokens[stack[i]].size();
  }

  std::string joined;
  joined.reserve(total);
  for (int i = 0; i < stack_idx; ++i) {
    if (i > 0) joined.append(" ");
    joined.append(tokens[stack[i]]);
  }
  return joined;
}
422
// Enumerates all skip-grams of |tokens| permitted by |params| and returns a
// map from the space-joined ngram text to its token count. Boundary tokens
// on their own are excluded via the blacklist below.
std::unordered_map<std::string, int> ExtractSkipGramsImpl(
    const std::vector<std::string>& tokens, const SkipGramParams& params) {
  // Ignore positional tokens.
  static auto* blacklist = new std::unordered_set<std::string>({
      kStartToken,
      kEndToken,
      kEmptyToken,
  });

  std::unordered_map<std::string, int> res;

  // Stack stores the index of word used to generate ngram.
  // The size of stack is the size of ngram.
  std::vector<int> stack(params.ngram_size + 1, 0);
  // Stack index that indicates which depth the recursion is operating at.
  int stack_idx = 1;
  int num_words = tokens.size();

  // Depth-first enumeration with the recursion unrolled onto |stack|.
  while (stack_idx >= 0) {
    if (ShouldStepInRecursion(stack, stack_idx, num_words, params)) {
      // When current depth can fill with a new word
      // and the new word is within the max range to skip,
      // fill this word to stack, recurse into next depth.
      stack[stack_idx]++;
      stack_idx++;
      stack[stack_idx] = stack[stack_idx - 1];
    } else {
      if (ShouldIncludeCurrentNgram(params, stack_idx)) {
        // Add n-gram to tensor buffer when the stack has filled with enough
        // words to generate the ngram.
        std::string ngram = JoinTokensBySpace(stack, stack_idx, tokens);
        if (blacklist->find(ngram) == blacklist->end()) {
          res[ngram] = stack_idx;
        }
      }
      // When current depth cannot fill with a valid new word,
      // and not in last depth to generate ngram,
      // step back to previous depth to iterate to next possible word.
      stack_idx--;
    }
  }

  return res;
}
467
// Normalizes and tokenizes |input|, wraps the tokens in <S>/<E> boundary
// tokens, and returns the resulting skip-grams. |normalizer| may be null
// (normalization skipped).
std::unordered_map<std::string, int> ExtractSkipGrams(
    const std::string& input, ProjectionTokenizer* tokenizer,
    ProjectionNormalizer* normalizer, const SkipGramParams& params) {
  // Normalize the input.
  const std::string& normalized =
      normalizer == nullptr
          ? input
          : normalizer->Normalize(input, params.max_input_chars);

  // Split sentence to words.
  std::vector<std::string> tokens;
  if (params.char_level) {
    // NOTE: max_input_chars is passed as SplitByChar's token cap; since
    // char-level splitting emits one token per character, the effect is the
    // same length limit.
    tokens = SplitByChar(normalized.data(), normalized.size(),
                         params.max_input_chars);
  } else {
    tokens = tokenizer->Tokenize(normalized.data(), normalized.size(),
                                 params.max_input_chars, kAllTokens);
  }

  // Process tokens: lowercase and optionally strip trailing punctuation.
  for (int i = 0; i < tokens.size(); ++i) {
    if (params.preprocess) {
      tokens[i] = PreProcessString(tokens[i].data(), tokens[i].size(),
                                   params.remove_punctuation);
    }
  }

  tokens.insert(tokens.begin(), kStartToken);
  tokens.insert(tokens.end(), kEndToken);

  return ExtractSkipGramsImpl(tokens, params);
}
500 } // namespace
501 // Generates LSH projections for input strings. This uses the framework in
502 // `string_projection_base.h`, with the implementation details that the input is
503 // a string tensor of messages and the op will perform tokenization.
504 //
505 // Input:
506 // tensor[0]: Input message, string[...]
507 //
508 // Additional attributes:
509 // max_input_chars: int[]
510 // maximum number of input characters to use from each message.
511 // token_separators: string[]
512 // the list of separators used to tokenize the input.
513 // normalize_repetition: bool[]
514 // if true, remove repeated characters in tokens ('loool' -> 'lol').
515
// Index of the input message tensor in the op's input list.
static const int kInputMessage = 0;
517
518 class StringProjectionOp : public StringProjectionOpBase {
519 public:
StringProjectionOp(const flexbuffers::Map & custom_options)520 explicit StringProjectionOp(const flexbuffers::Map& custom_options)
521 : StringProjectionOpBase(custom_options),
522 projection_normalizer_(
523 custom_options["token_separators"].AsString().str(),
524 custom_options["normalize_repetition"].AsBool()),
525 projection_tokenizer_(" ") {
526 if (custom_options["max_input_chars"].IsInt()) {
527 skip_gram_params().max_input_chars =
528 custom_options["max_input_chars"].AsInt32();
529 }
530 }
531
InitializeInput(TfLiteContext * context,TfLiteNode * node)532 TfLiteStatus InitializeInput(TfLiteContext* context,
533 TfLiteNode* node) override {
534 input_ = &context->tensors[node->inputs->data[kInputMessage]];
535 return kTfLiteOk;
536 }
537
ExtractSkipGrams(int i)538 std::unordered_map<std::string, int> ExtractSkipGrams(int i) override {
539 StringRef input = GetString(input_, i);
540 return ::tflite::ops::custom::libtextclassifier3::string_projection::
541 ExtractSkipGrams({input.str, static_cast<size_t>(input.len)},
542 &projection_tokenizer_, &projection_normalizer_,
543 skip_gram_params());
544 }
545
FinalizeInput()546 void FinalizeInput() override { input_ = nullptr; }
547
GetInputShape(TfLiteContext * context,TfLiteNode * node)548 TfLiteIntArray* GetInputShape(TfLiteContext* context,
549 TfLiteNode* node) override {
550 return context->tensors[node->inputs->data[kInputMessage]].dims;
551 }
552
553 private:
554 ProjectionNormalizer projection_normalizer_;
555 ProjectionTokenizer projection_tokenizer_;
556
557 TfLiteTensor* input_;
558 };
559
Init(TfLiteContext * context,const char * buffer,size_t length)560 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
561 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
562 return new StringProjectionOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
563 }
564
565 } // namespace string_projection
566
567 // This op converts a list of strings to integers via LSH projections.
TfLiteRegistration* Register_STRING_PROJECTION() {
  // Only Init is defined in this file; Free, Resize and Eval are presumably
  // provided by the shared base (string_projection_base.h) — they are not
  // defined in this translation unit.
  static TfLiteRegistration r = {libtextclassifier3::string_projection::Init,
                                 libtextclassifier3::string_projection::Free,
                                 libtextclassifier3::string_projection::Resize,
                                 libtextclassifier3::string_projection::Eval};
  return &r;
}
575
576 } // namespace libtextclassifier3
577 } // namespace custom
578 } // namespace ops
579 } // namespace tflite
580