1*993b0882SAndroid Build Coastguard Worker /* 2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project 3*993b0882SAndroid Build Coastguard Worker * 4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*993b0882SAndroid Build Coastguard Worker * 8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*993b0882SAndroid Build Coastguard Worker * 10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*993b0882SAndroid Build Coastguard Worker * limitations under the License. 15*993b0882SAndroid Build Coastguard Worker */ 16*993b0882SAndroid Build Coastguard Worker 17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ 18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ 19*993b0882SAndroid Build Coastguard Worker 20*993b0882SAndroid Build Coastguard Worker #include <string> 21*993b0882SAndroid Build Coastguard Worker #include <vector> 22*993b0882SAndroid Build Coastguard Worker 23*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h" 24*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h" 25*993b0882SAndroid Build Coastguard Worker #include "absl/strings/string_view.h" 26*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h" 27*993b0882SAndroid Build Coastguard Worker 28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 29*993b0882SAndroid Build Coastguard Worker 30*993b0882SAndroid Build Coastguard Worker // SkipgramFinder finds skipgrams in strings. 31*993b0882SAndroid Build Coastguard Worker // 32*993b0882SAndroid Build Coastguard Worker // To use: First, add skipgrams using AddSkipgram() - each skipgram is 33*993b0882SAndroid Build Coastguard Worker // associated with some category. Then, call FindSkipgrams() on a string, 34*993b0882SAndroid Build Coastguard Worker // which will return the set of categories of the skipgrams in the string. 35*993b0882SAndroid Build Coastguard Worker // 36*993b0882SAndroid Build Coastguard Worker // Both the skipgrams and the input strings will be tokenzied by splitting 37*993b0882SAndroid Build Coastguard Worker // on spaces. Additionally, the tokens will be lowercased and have any 38*993b0882SAndroid Build Coastguard Worker // trailing punctuation removed. 39*993b0882SAndroid Build Coastguard Worker class SkipgramFinder { 40*993b0882SAndroid Build Coastguard Worker public: SkipgramFinder(int max_skip_size)41*993b0882SAndroid Build Coastguard Worker explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {} 42*993b0882SAndroid Build Coastguard Worker 43*993b0882SAndroid Build Coastguard Worker // Adds a skipgram that SkipgramFinder should look for in input strings. 44*993b0882SAndroid Build Coastguard Worker // Tokens may use the regex '.*' as a suffix. 45*993b0882SAndroid Build Coastguard Worker void AddSkipgram(const std::string& skipgram, int category); 46*993b0882SAndroid Build Coastguard Worker 47*993b0882SAndroid Build Coastguard Worker // Find all of the skipgrams in `input`, and return their categories. 48*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> FindSkipgrams(const std::string& input) const; 49*993b0882SAndroid Build Coastguard Worker 50*993b0882SAndroid Build Coastguard Worker // Find all of the skipgrams in `tokens`, and return their categories. 51*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> FindSkipgrams( 52*993b0882SAndroid Build Coastguard Worker const std::vector<absl::string_view>& tokens) const; 53*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> FindSkipgrams( 54*993b0882SAndroid Build Coastguard Worker const std::vector<::tflite::StringRef>& tokens) const; 55*993b0882SAndroid Build Coastguard Worker 56*993b0882SAndroid Build Coastguard Worker private: 57*993b0882SAndroid Build Coastguard Worker struct TrieNode { 58*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> categories; 59*993b0882SAndroid Build Coastguard Worker // Maps tokens to the next node in the trie. 60*993b0882SAndroid Build Coastguard Worker absl::flat_hash_map<std::string, TrieNode> token_to_node; 61*993b0882SAndroid Build Coastguard Worker // Maps token prefixes (<prefix>.*) to the next node in the trie. 62*993b0882SAndroid Build Coastguard Worker absl::flat_hash_map<std::string, TrieNode> prefix_to_node; 63*993b0882SAndroid Build Coastguard Worker }; 64*993b0882SAndroid Build Coastguard Worker 65*993b0882SAndroid Build Coastguard Worker TrieNode skipgram_trie_; 66*993b0882SAndroid Build Coastguard Worker int max_skip_size_; 67*993b0882SAndroid Build Coastguard Worker }; 68*993b0882SAndroid Build Coastguard Worker 69*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3 70*993b0882SAndroid Build Coastguard Worker #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ 71