xref: /aosp_15_r20/external/libtextclassifier/native/actions/ngram-model.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/ngram-model.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <algorithm>
20*993b0882SAndroid Build Coastguard Worker 
21*993b0882SAndroid Build Coastguard Worker #include "actions/feature-processor.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/hash/farmhash.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/strings/stringpiece.h"
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker // An iterator to iterate over the initial tokens of the n-grams of a model.
29*993b0882SAndroid Build Coastguard Worker class FirstTokenIterator
30*993b0882SAndroid Build Coastguard Worker     : public std::iterator<std::random_access_iterator_tag,
31*993b0882SAndroid Build Coastguard Worker                            /*value_type=*/uint32, /*difference_type=*/ptrdiff_t,
32*993b0882SAndroid Build Coastguard Worker                            /*pointer=*/const uint32*,
33*993b0882SAndroid Build Coastguard Worker                            /*reference=*/uint32&> {
34*993b0882SAndroid Build Coastguard Worker  public:
FirstTokenIterator(const NGramLinearRegressionModel * model,int index)35*993b0882SAndroid Build Coastguard Worker   explicit FirstTokenIterator(const NGramLinearRegressionModel* model,
36*993b0882SAndroid Build Coastguard Worker                               int index)
37*993b0882SAndroid Build Coastguard Worker       : model_(model), index_(index) {}
38*993b0882SAndroid Build Coastguard Worker 
operator ++()39*993b0882SAndroid Build Coastguard Worker   FirstTokenIterator& operator++() {
40*993b0882SAndroid Build Coastguard Worker     index_++;
41*993b0882SAndroid Build Coastguard Worker     return *this;
42*993b0882SAndroid Build Coastguard Worker   }
operator +=(ptrdiff_t dist)43*993b0882SAndroid Build Coastguard Worker   FirstTokenIterator& operator+=(ptrdiff_t dist) {
44*993b0882SAndroid Build Coastguard Worker     index_ += dist;
45*993b0882SAndroid Build Coastguard Worker     return *this;
46*993b0882SAndroid Build Coastguard Worker   }
operator -(const FirstTokenIterator & other_it) const47*993b0882SAndroid Build Coastguard Worker   ptrdiff_t operator-(const FirstTokenIterator& other_it) const {
48*993b0882SAndroid Build Coastguard Worker     return index_ - other_it.index_;
49*993b0882SAndroid Build Coastguard Worker   }
operator *() const50*993b0882SAndroid Build Coastguard Worker   uint32 operator*() const {
51*993b0882SAndroid Build Coastguard Worker     const uint32 token_offset = (*model_->ngram_start_offsets())[index_];
52*993b0882SAndroid Build Coastguard Worker     return (*model_->hashed_ngram_tokens())[token_offset];
53*993b0882SAndroid Build Coastguard Worker   }
index() const54*993b0882SAndroid Build Coastguard Worker   int index() const { return index_; }
55*993b0882SAndroid Build Coastguard Worker 
56*993b0882SAndroid Build Coastguard Worker  private:
57*993b0882SAndroid Build Coastguard Worker   const NGramLinearRegressionModel* model_;
58*993b0882SAndroid Build Coastguard Worker   int index_;
59*993b0882SAndroid Build Coastguard Worker };
60*993b0882SAndroid Build Coastguard Worker 
61*993b0882SAndroid Build Coastguard Worker }  // anonymous namespace
62*993b0882SAndroid Build Coastguard Worker 
Create(const UniLib * unilib,const NGramLinearRegressionModel * model,const Tokenizer * tokenizer)63*993b0882SAndroid Build Coastguard Worker std::unique_ptr<NGramSensitiveModel> NGramSensitiveModel::Create(
64*993b0882SAndroid Build Coastguard Worker     const UniLib* unilib, const NGramLinearRegressionModel* model,
65*993b0882SAndroid Build Coastguard Worker     const Tokenizer* tokenizer) {
66*993b0882SAndroid Build Coastguard Worker   if (model == nullptr) {
67*993b0882SAndroid Build Coastguard Worker     return nullptr;
68*993b0882SAndroid Build Coastguard Worker   }
69*993b0882SAndroid Build Coastguard Worker   if (tokenizer == nullptr && model->tokenizer_options() == nullptr) {
70*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No tokenizer options specified.";
71*993b0882SAndroid Build Coastguard Worker     return nullptr;
72*993b0882SAndroid Build Coastguard Worker   }
73*993b0882SAndroid Build Coastguard Worker   return std::unique_ptr<NGramSensitiveModel>(
74*993b0882SAndroid Build Coastguard Worker       new NGramSensitiveModel(unilib, model, tokenizer));
75*993b0882SAndroid Build Coastguard Worker }
76*993b0882SAndroid Build Coastguard Worker 
NGramSensitiveModel(const UniLib * unilib,const NGramLinearRegressionModel * model,const Tokenizer * tokenizer)77*993b0882SAndroid Build Coastguard Worker NGramSensitiveModel::NGramSensitiveModel(
78*993b0882SAndroid Build Coastguard Worker     const UniLib* unilib, const NGramLinearRegressionModel* model,
79*993b0882SAndroid Build Coastguard Worker     const Tokenizer* tokenizer)
80*993b0882SAndroid Build Coastguard Worker     : model_(model) {
81*993b0882SAndroid Build Coastguard Worker   // Create new tokenizer if options are specified, reuse feature processor
82*993b0882SAndroid Build Coastguard Worker   // tokenizer otherwise.
83*993b0882SAndroid Build Coastguard Worker   if (model->tokenizer_options() != nullptr) {
84*993b0882SAndroid Build Coastguard Worker     owned_tokenizer_ = CreateTokenizer(model->tokenizer_options(), unilib);
85*993b0882SAndroid Build Coastguard Worker     tokenizer_ = owned_tokenizer_.get();
86*993b0882SAndroid Build Coastguard Worker   } else {
87*993b0882SAndroid Build Coastguard Worker     tokenizer_ = tokenizer;
88*993b0882SAndroid Build Coastguard Worker   }
89*993b0882SAndroid Build Coastguard Worker }
90*993b0882SAndroid Build Coastguard Worker 
91*993b0882SAndroid Build Coastguard Worker // Returns whether a given n-gram matches the token stream.
IsNGramMatch(const uint32 * tokens,size_t num_tokens,const uint32 * ngram_tokens,size_t num_ngram_tokens,int max_skips) const92*993b0882SAndroid Build Coastguard Worker bool NGramSensitiveModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
93*993b0882SAndroid Build Coastguard Worker                                        const uint32* ngram_tokens,
94*993b0882SAndroid Build Coastguard Worker                                        size_t num_ngram_tokens,
95*993b0882SAndroid Build Coastguard Worker                                        int max_skips) const {
96*993b0882SAndroid Build Coastguard Worker   int token_idx = 0, ngram_token_idx = 0, skip_remain = 0;
97*993b0882SAndroid Build Coastguard Worker   for (; token_idx < num_tokens && ngram_token_idx < num_ngram_tokens;) {
98*993b0882SAndroid Build Coastguard Worker     if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
99*993b0882SAndroid Build Coastguard Worker       // Token matches. Advance both and reset the skip budget.
100*993b0882SAndroid Build Coastguard Worker       ++token_idx;
101*993b0882SAndroid Build Coastguard Worker       ++ngram_token_idx;
102*993b0882SAndroid Build Coastguard Worker       skip_remain = max_skips;
103*993b0882SAndroid Build Coastguard Worker     } else if (skip_remain > 0) {
104*993b0882SAndroid Build Coastguard Worker       // No match, but we have skips left, so just advance over the token.
105*993b0882SAndroid Build Coastguard Worker       ++token_idx;
106*993b0882SAndroid Build Coastguard Worker       skip_remain--;
107*993b0882SAndroid Build Coastguard Worker     } else {
108*993b0882SAndroid Build Coastguard Worker       // No match and we're out of skips. Reject.
109*993b0882SAndroid Build Coastguard Worker       return false;
110*993b0882SAndroid Build Coastguard Worker     }
111*993b0882SAndroid Build Coastguard Worker   }
112*993b0882SAndroid Build Coastguard Worker   return ngram_token_idx == num_ngram_tokens;
113*993b0882SAndroid Build Coastguard Worker }
114*993b0882SAndroid Build Coastguard Worker 
115*993b0882SAndroid Build Coastguard Worker // Calculates the total number of skip-grams that can be created for a stream
116*993b0882SAndroid Build Coastguard Worker // with the given number of tokens.
GetNumSkipGrams(int num_tokens,int max_ngram_length,int max_skips)117*993b0882SAndroid Build Coastguard Worker uint64 NGramSensitiveModel::GetNumSkipGrams(int num_tokens,
118*993b0882SAndroid Build Coastguard Worker                                             int max_ngram_length,
119*993b0882SAndroid Build Coastguard Worker                                             int max_skips) {
120*993b0882SAndroid Build Coastguard Worker   // Start with unigrams.
121*993b0882SAndroid Build Coastguard Worker   uint64 total = num_tokens;
122*993b0882SAndroid Build Coastguard Worker   for (int ngram_len = 2;
123*993b0882SAndroid Build Coastguard Worker        ngram_len <= max_ngram_length && ngram_len <= num_tokens; ++ngram_len) {
124*993b0882SAndroid Build Coastguard Worker     // We can easily compute the expected length of the n-gram (with skips),
125*993b0882SAndroid Build Coastguard Worker     // but it doesn't account for the fact that they may be longer than the
126*993b0882SAndroid Build Coastguard Worker     // input and should be pruned.
127*993b0882SAndroid Build Coastguard Worker     // Instead, we iterate over the distribution of effective n-gram lengths
128*993b0882SAndroid Build Coastguard Worker     // and add each length individually.
129*993b0882SAndroid Build Coastguard Worker     const int num_gaps = ngram_len - 1;
130*993b0882SAndroid Build Coastguard Worker     const int len_min = ngram_len;
131*993b0882SAndroid Build Coastguard Worker     const int len_max = ngram_len + num_gaps * max_skips;
132*993b0882SAndroid Build Coastguard Worker     const int len_mid = (len_max + len_min) / 2;
133*993b0882SAndroid Build Coastguard Worker     for (int len_i = len_min; len_i <= len_max; ++len_i) {
134*993b0882SAndroid Build Coastguard Worker       if (len_i > num_tokens) continue;
135*993b0882SAndroid Build Coastguard Worker       const int num_configs_of_len_i =
136*993b0882SAndroid Build Coastguard Worker           len_i <= len_mid ? len_i - len_min + 1 : len_max - len_i + 1;
137*993b0882SAndroid Build Coastguard Worker       const int num_start_offsets = num_tokens - len_i + 1;
138*993b0882SAndroid Build Coastguard Worker       total += num_configs_of_len_i * num_start_offsets;
139*993b0882SAndroid Build Coastguard Worker     }
140*993b0882SAndroid Build Coastguard Worker   }
141*993b0882SAndroid Build Coastguard Worker   return total;
142*993b0882SAndroid Build Coastguard Worker }
143*993b0882SAndroid Build Coastguard Worker 
GetFirstTokenMatches(uint32 token_hash) const144*993b0882SAndroid Build Coastguard Worker std::pair<int, int> NGramSensitiveModel::GetFirstTokenMatches(
145*993b0882SAndroid Build Coastguard Worker     uint32 token_hash) const {
146*993b0882SAndroid Build Coastguard Worker   const int num_ngrams = model_->ngram_weights()->size();
147*993b0882SAndroid Build Coastguard Worker   const auto start_it = FirstTokenIterator(model_, 0);
148*993b0882SAndroid Build Coastguard Worker   const auto end_it = FirstTokenIterator(model_, num_ngrams);
149*993b0882SAndroid Build Coastguard Worker   const int start = std::lower_bound(start_it, end_it, token_hash).index();
150*993b0882SAndroid Build Coastguard Worker   const int end = std::upper_bound(start_it, end_it, token_hash).index();
151*993b0882SAndroid Build Coastguard Worker   return std::make_pair(start, end);
152*993b0882SAndroid Build Coastguard Worker }
153*993b0882SAndroid Build Coastguard Worker 
Eval(const UnicodeText & text) const154*993b0882SAndroid Build Coastguard Worker std::pair<bool, float> NGramSensitiveModel::Eval(
155*993b0882SAndroid Build Coastguard Worker     const UnicodeText& text) const {
156*993b0882SAndroid Build Coastguard Worker   const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
157*993b0882SAndroid Build Coastguard Worker 
158*993b0882SAndroid Build Coastguard Worker   // If we have no tokens, then just bail early.
159*993b0882SAndroid Build Coastguard Worker   if (raw_tokens.empty()) {
160*993b0882SAndroid Build Coastguard Worker     return std::make_pair(false, model_->default_token_weight());
161*993b0882SAndroid Build Coastguard Worker   }
162*993b0882SAndroid Build Coastguard Worker 
163*993b0882SAndroid Build Coastguard Worker   // Hash the tokens.
164*993b0882SAndroid Build Coastguard Worker   std::vector<uint32> tokens;
165*993b0882SAndroid Build Coastguard Worker   tokens.reserve(raw_tokens.size());
166*993b0882SAndroid Build Coastguard Worker   for (const Token& raw_token : raw_tokens) {
167*993b0882SAndroid Build Coastguard Worker     tokens.push_back(tc3farmhash::Fingerprint32(raw_token.value.data(),
168*993b0882SAndroid Build Coastguard Worker                                                 raw_token.value.length()));
169*993b0882SAndroid Build Coastguard Worker   }
170*993b0882SAndroid Build Coastguard Worker 
171*993b0882SAndroid Build Coastguard Worker   // Calculate the total number of skip-grams that can be generated for the
172*993b0882SAndroid Build Coastguard Worker   // input text.
173*993b0882SAndroid Build Coastguard Worker   const uint64 num_candidates = GetNumSkipGrams(
174*993b0882SAndroid Build Coastguard Worker       tokens.size(), model_->max_denom_ngram_length(), model_->max_skips());
175*993b0882SAndroid Build Coastguard Worker 
176*993b0882SAndroid Build Coastguard Worker   // For each token, see whether it denotes the start of an n-gram in the model.
177*993b0882SAndroid Build Coastguard Worker   int num_matches = 0;
178*993b0882SAndroid Build Coastguard Worker   float weight_matches = 0.f;
179*993b0882SAndroid Build Coastguard Worker   for (size_t start_i = 0; start_i < tokens.size(); ++start_i) {
180*993b0882SAndroid Build Coastguard Worker     const std::pair<int, int> ngram_range =
181*993b0882SAndroid Build Coastguard Worker         GetFirstTokenMatches(tokens[start_i]);
182*993b0882SAndroid Build Coastguard Worker     for (int ngram_idx = ngram_range.first; ngram_idx < ngram_range.second;
183*993b0882SAndroid Build Coastguard Worker          ++ngram_idx) {
184*993b0882SAndroid Build Coastguard Worker       const uint16 ngram_tokens_begin =
185*993b0882SAndroid Build Coastguard Worker           (*model_->ngram_start_offsets())[ngram_idx];
186*993b0882SAndroid Build Coastguard Worker       const uint16 ngram_tokens_end =
187*993b0882SAndroid Build Coastguard Worker           (*model_->ngram_start_offsets())[ngram_idx + 1];
188*993b0882SAndroid Build Coastguard Worker       if (IsNGramMatch(
189*993b0882SAndroid Build Coastguard Worker               /*tokens=*/tokens.data() + start_i,
190*993b0882SAndroid Build Coastguard Worker               /*num_tokens=*/tokens.size() - start_i,
191*993b0882SAndroid Build Coastguard Worker               /*ngram_tokens=*/model_->hashed_ngram_tokens()->data() +
192*993b0882SAndroid Build Coastguard Worker                   ngram_tokens_begin,
193*993b0882SAndroid Build Coastguard Worker               /*num_ngram_tokens=*/ngram_tokens_end - ngram_tokens_begin,
194*993b0882SAndroid Build Coastguard Worker               /*max_skips=*/model_->max_skips())) {
195*993b0882SAndroid Build Coastguard Worker         ++num_matches;
196*993b0882SAndroid Build Coastguard Worker         weight_matches += (*model_->ngram_weights())[ngram_idx];
197*993b0882SAndroid Build Coastguard Worker       }
198*993b0882SAndroid Build Coastguard Worker     }
199*993b0882SAndroid Build Coastguard Worker   }
200*993b0882SAndroid Build Coastguard Worker 
201*993b0882SAndroid Build Coastguard Worker   // Calculate the score.
202*993b0882SAndroid Build Coastguard Worker   const int num_misses = num_candidates - num_matches;
203*993b0882SAndroid Build Coastguard Worker   const float internal_score =
204*993b0882SAndroid Build Coastguard Worker       (weight_matches + (model_->default_token_weight() * num_misses)) /
205*993b0882SAndroid Build Coastguard Worker       num_candidates;
206*993b0882SAndroid Build Coastguard Worker   return std::make_pair(internal_score > model_->threshold(), internal_score);
207*993b0882SAndroid Build Coastguard Worker }
208*993b0882SAndroid Build Coastguard Worker 
EvalConversation(const Conversation & conversation,const int num_messages) const209*993b0882SAndroid Build Coastguard Worker std::pair<bool, float> NGramSensitiveModel::EvalConversation(
210*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const int num_messages) const {
211*993b0882SAndroid Build Coastguard Worker   float score = 0.0;
212*993b0882SAndroid Build Coastguard Worker   for (int i = 1; i <= num_messages; i++) {
213*993b0882SAndroid Build Coastguard Worker     const std::string& message =
214*993b0882SAndroid Build Coastguard Worker         conversation.messages[conversation.messages.size() - i].text;
215*993b0882SAndroid Build Coastguard Worker     const UnicodeText message_unicode(
216*993b0882SAndroid Build Coastguard Worker         UTF8ToUnicodeText(message, /*do_copy=*/false));
217*993b0882SAndroid Build Coastguard Worker     // Run ngram linear regression model.
218*993b0882SAndroid Build Coastguard Worker     const auto prediction = Eval(message_unicode);
219*993b0882SAndroid Build Coastguard Worker     if (prediction.first) {
220*993b0882SAndroid Build Coastguard Worker       return prediction;
221*993b0882SAndroid Build Coastguard Worker     }
222*993b0882SAndroid Build Coastguard Worker     score = std::max(score, prediction.second);
223*993b0882SAndroid Build Coastguard Worker   }
224*993b0882SAndroid Build Coastguard Worker   return std::make_pair(false, score);
225*993b0882SAndroid Build Coastguard Worker }
226*993b0882SAndroid Build Coastguard Worker 
227*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
228