1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/utils/ir.h"
18
19 #include <algorithm>
20
21 #include "utils/i18n/locale.h"
22 #include "utils/strings/append.h"
23 #include "utils/strings/stringpiece.h"
24 #include "utils/zlib/zlib.h"
25
26 namespace libtextclassifier3::grammar {
27 namespace {
28
29 constexpr size_t kMaxHashTableSize = 100;
30
// Sorts a vector of pointer-like entries by their `key` field so the runtime
// can locate entries via binary search. Uses stable_sort to keep the relative
// order of entries with equal keys.
template <typename T>
void SortForBinarySearchLookup(T* entries) {
  const auto by_key = [](const auto& lhs, const auto& rhs) {
    return lhs->key < rhs->key;
  };
  std::stable_sort(entries->begin(), entries->end(), by_key);
}
37
// Sorts a vector of structs by the value returned from their `key()` accessor
// so the runtime can binary-search them. Stable, so insertion order among
// equal keys is preserved.
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
  const auto by_key = [](const auto& lhs, const auto& rhs) {
    return lhs.key() < rhs.key();
  };
  std::stable_sort(entries->begin(), entries->end(), by_key);
}
44
IsSameLhs(const Ir::Lhs & lhs,const RulesSet_::Lhs & other)45 bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
46 return (lhs.nonterminal == other.nonterminal() &&
47 lhs.callback.id == other.callback_id() &&
48 lhs.callback.param == other.callback_param() &&
49 lhs.preconditions.max_whitespace_gap == other.max_whitespace_gap());
50 }
51
IsSameLhsEntry(const Ir::Lhs & lhs,const int32 lhs_entry,const std::vector<RulesSet_::Lhs> & candidates)52 bool IsSameLhsEntry(const Ir::Lhs& lhs, const int32 lhs_entry,
53 const std::vector<RulesSet_::Lhs>& candidates) {
54 // Simple case: direct encoding of the nonterminal.
55 if (lhs_entry > 0) {
56 return (lhs.nonterminal == lhs_entry && lhs.callback.id == kNoCallback &&
57 lhs.preconditions.max_whitespace_gap == -1);
58 }
59
60 // Entry is index into callback lookup.
61 return IsSameLhs(lhs, candidates[-lhs_entry]);
62 }
63
IsSameLhsSet(const Ir::LhsSet & lhs_set,const RulesSet_::LhsSetT & candidate,const std::vector<RulesSet_::Lhs> & candidates)64 bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
65 const RulesSet_::LhsSetT& candidate,
66 const std::vector<RulesSet_::Lhs>& candidates) {
67 if (lhs_set.size() != candidate.lhs.size()) {
68 return false;
69 }
70
71 for (int i = 0; i < lhs_set.size(); i++) {
72 // Check that entries are the same.
73 if (!IsSameLhsEntry(lhs_set[i], candidate.lhs[i], candidates)) {
74 return false;
75 }
76 }
77
78 return true;
79 }
80
// Returns a copy of `lhs_set` sorted by (nonterminal, callback id, callback
// param, max whitespace gap) so that equivalent sets have a canonical order.
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
  Ir::LhsSet sorted_lhs = lhs_set;
  std::stable_sort(
      sorted_lhs.begin(), sorted_lhs.end(),
      [](const Ir::Lhs& a, const Ir::Lhs& b) {
        return std::tie(a.nonterminal, a.callback.id, a.callback.param,
                        a.preconditions.max_whitespace_gap) <
               std::tie(b.nonterminal, b.callback.id, b.callback.param,
                        b.preconditions.max_whitespace_gap);
      });
  // Fix: return the sorted copy. Previously the unsorted input `lhs_set` was
  // returned, which silently discarded the sort.
  return sorted_lhs;
}
93
94 // Adds a new lhs match set to the output.
95 // Reuses the same set, if it was previously observed.
int AddLhsSet(const Ir::LhsSet& lhs_set, RulesSetT* rules_set) {
  Ir::LhsSet sorted_lhs = SortedLhsSet(lhs_set);
  // NOTE(review): `sorted_lhs` is computed but never used below; both the
  // dedup check and the serialization loop operate on the unsorted `lhs_set`.
  // Presumably the sorted copy was meant to make set matching
  // order-insensitive — confirm intent before using or removing it.
  // Check whether we can reuse an entry.
  const int output_size = rules_set->lhs_set.size();
  for (int i = 0; i < output_size; i++) {
    if (IsSameLhsSet(lhs_set, *rules_set->lhs_set[i], rules_set->lhs)) {
      // An identical set was serialized before; reuse its index.
      return i;
    }
  }

  // Add new entry.
  rules_set->lhs_set.emplace_back(std::make_unique<RulesSet_::LhsSetT>());
  RulesSet_::LhsSetT* serialized_lhs_set = rules_set->lhs_set.back().get();
  for (const Ir::Lhs& lhs : lhs_set) {
    // Simple case: No callback and no special requirements, we directly encode
    // the nonterminal.
    if (lhs.callback.id == kNoCallback &&
        lhs.preconditions.max_whitespace_gap < 0) {
      serialized_lhs_set->lhs.push_back(lhs.nonterminal);
    } else {
      // Check whether we can reuse a callback entry.
      const int lhs_size = rules_set->lhs.size();
      bool found_entry = false;
      for (int i = 0; i < lhs_size; i++) {
        if (IsSameLhs(lhs, rules_set->lhs[i])) {
          found_entry = true;
          // Non-positive values encode an index into the callback table
          // (see IsSameLhsEntry).
          serialized_lhs_set->lhs.push_back(-i);
          break;
        }
      }

      // We could reuse an existing entry.
      if (found_entry) {
        continue;
      }

      // Add a new one.
      rules_set->lhs.push_back(
          RulesSet_::Lhs(lhs.nonterminal, lhs.callback.id, lhs.callback.param,
                         lhs.preconditions.max_whitespace_gap));
      // `lhs_size` is the index of the entry just appended.
      serialized_lhs_set->lhs.push_back(-lhs_size);
    }
  }
  // Index of the newly appended set.
  return output_size;
}
141
142 // Serializes a unary rules table.
SerializeUnaryRulesShard(const std::unordered_map<Nonterm,Ir::LhsSet> & unary_rules,RulesSetT * rules_set,RulesSet_::RulesT * rules)143 void SerializeUnaryRulesShard(
144 const std::unordered_map<Nonterm, Ir::LhsSet>& unary_rules,
145 RulesSetT* rules_set, RulesSet_::RulesT* rules) {
146 for (const auto& it : unary_rules) {
147 rules->unary_rules.push_back(RulesSet_::Rules_::UnaryRulesEntry(
148 it.first, AddLhsSet(it.second, rules_set)));
149 }
150 SortStructsForBinarySearchLookup(&rules->unary_rules);
151 }
152
// Serializes a binary rules table.
SerializeBinaryRulesShard(const std::unordered_map<TwoNonterms,Ir::LhsSet,BinaryRuleHasher> & binary_rules,RulesSetT * rules_set,RulesSet_::RulesT * rules)154 void SerializeBinaryRulesShard(
155 const std::unordered_map<TwoNonterms, Ir::LhsSet, BinaryRuleHasher>&
156 binary_rules,
157 RulesSetT* rules_set, RulesSet_::RulesT* rules) {
158 const size_t num_buckets = std::min(binary_rules.size(), kMaxHashTableSize);
159 for (int i = 0; i < num_buckets; i++) {
160 rules->binary_rules.emplace_back(
161 new RulesSet_::Rules_::BinaryRuleTableBucketT());
162 }
163
164 // Serialize the table.
165 BinaryRuleHasher hash;
166 for (const auto& it : binary_rules) {
167 const TwoNonterms key = it.first;
168 uint32 bucket_index = hash(key) % num_buckets;
169
170 // Add entry to bucket chain list.
171 rules->binary_rules[bucket_index]->rules.push_back(
172 RulesSet_::Rules_::BinaryRule(key.first, key.second,
173 AddLhsSet(it.second, rules_set)));
174 }
175 }
176
177 } // namespace
178
// Adds `lhs` to `lhs_set`, reusing an existing entry (and thereby its
// nonterminal id) whenever that is possible without changing semantics.
// Returns the nonterminal that ends up representing `lhs`.
Nonterm Ir::AddToSet(const Lhs& lhs, LhsSet* lhs_set) {
  const int lhs_set_size = lhs_set->size();
  Nonterm shareable_nonterm = lhs.nonterminal;
  for (int i = 0; i < lhs_set_size; i++) {
    Lhs* candidate = &lhs_set->at(i);

    // Exact match, just reuse rule.
    if (lhs == *candidate) {
      return candidate->nonterminal;
    }

    // Cannot reuse unshareable ids.
    if (nonshareable_.find(candidate->nonterminal) != nonshareable_.end() ||
        nonshareable_.find(lhs.nonterminal) != nonshareable_.end()) {
      continue;
    }

    // Cannot reuse id if the preconditions are different.
    if (!(lhs.preconditions == candidate->preconditions)) {
      continue;
    }

    // If the nonterminal is already defined, it must match for sharing.
    if (lhs.nonterminal != kUnassignedNonterm &&
        lhs.nonterminal != candidate->nonterminal) {
      continue;
    }

    // Check whether the callbacks match.
    if (lhs.callback == candidate->callback) {
      return candidate->nonterminal;
    }

    // We can reuse if one of the output callbacks is not used.
    if (lhs.callback.id == kNoCallback) {
      return candidate->nonterminal;
    } else if (candidate->callback.id == kNoCallback) {
      // Old entry has no output callback, which is redundant now.
      // NOTE: this mutates the existing entry in place.
      candidate->callback = lhs.callback;
      return candidate->nonterminal;
    }

    // We can share the nonterminal, but we need to
    // add a new output callback. Defer this as we might find a shareable
    // nonterminal first.
    shareable_nonterm = candidate->nonterminal;
  }

  // We didn't find a redundant entry, so create a new one.
  // DefineNonterminal assigns a fresh id if `shareable_nonterm` is still
  // unassigned.
  shareable_nonterm = DefineNonterminal(shareable_nonterm);
  lhs_set->push_back(Lhs{shareable_nonterm, lhs.callback, lhs.preconditions});
  return shareable_nonterm;
}
232
Add(const Lhs & lhs,const std::string & terminal,const bool case_sensitive,const int shard)233 Nonterm Ir::Add(const Lhs& lhs, const std::string& terminal,
234 const bool case_sensitive, const int shard) {
235 TC3_CHECK_LT(shard, shards_.size());
236 if (case_sensitive) {
237 return AddRule(lhs, terminal, &shards_[shard].terminal_rules);
238 } else {
239 return AddRule(lhs, terminal, &shards_[shard].lowercase_terminal_rules);
240 }
241 }
242
// For latency, the intermediate sub-rules of a chain are placed on the first
// shard, which must be the any-locale ('*') shard: '*' rules are always
// considered while parsing the tree, and because the sub-rules live only on
// shard one they are deduplicated correctly.
Add(const Lhs & lhs,const std::vector<Nonterm> & rhs,const int shard)246 Nonterm Ir::Add(const Lhs& lhs, const std::vector<Nonterm>& rhs,
247 const int shard) {
248 // Add a new unary rule.
249 if (rhs.size() == 1) {
250 return Add(lhs, rhs.front(), shard);
251 }
252
253 // Add a chain of (rhs.size() - 1) binary rules.
254 Nonterm prev = rhs.front();
255 for (int i = 1; i < rhs.size() - 1; i++) {
256 prev = Add(kUnassignedNonterm, prev, rhs[i]);
257 }
258 return Add(lhs, prev, rhs.back());
259 }
260
// Registers a regex rule: matches of `regex_pattern` produce `lhs`.
// The nonterminal is defined first (a fresh id is assigned if `lhs` is
// unassigned) and the id actually used is returned.
Nonterm Ir::AddRegex(Nonterm lhs, const std::string& regex_pattern) {
  lhs = DefineNonterminal(lhs);
  regex_rules_.emplace_back(regex_pattern, lhs);
  return lhs;
}
266
// Associates the annotation name `annotation` with the nonterminal `lhs`.
// The collected pairs are serialized into the nonterminals table by
// Serialize().
void Ir::AddAnnotation(const Nonterm lhs, const std::string& annotation) {
  annotations_.emplace_back(annotation, lhs);
}
270
271 // Serializes the terminal rules table.
SerializeTerminalRules(RulesSetT * rules_set,std::vector<std::unique_ptr<RulesSet_::RulesT>> * rules_shards) const272 void Ir::SerializeTerminalRules(
273 RulesSetT* rules_set,
274 std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const {
275 // Use common pool for all terminals.
276 struct TerminalEntry {
277 std::string terminal;
278 int set_index;
279 int index;
280 Ir::LhsSet lhs_set;
281 };
282 std::vector<TerminalEntry> terminal_rules;
283
284 // Merge all terminals into a common pool.
285 // We want to use one common pool, but still need to track which set they
286 // belong to.
287 std::vector<const std::unordered_map<std::string, Ir::LhsSet>*>
288 terminal_rules_sets;
289 std::vector<RulesSet_::Rules_::TerminalRulesMapT*> rules_maps;
290 terminal_rules_sets.reserve(2 * shards_.size());
291 rules_maps.reserve(terminal_rules_sets.size());
292 for (int i = 0; i < shards_.size(); i++) {
293 terminal_rules_sets.push_back(&shards_[i].terminal_rules);
294 terminal_rules_sets.push_back(&shards_[i].lowercase_terminal_rules);
295 rules_shards->at(i)->terminal_rules.reset(
296 new RulesSet_::Rules_::TerminalRulesMapT());
297 rules_shards->at(i)->lowercase_terminal_rules.reset(
298 new RulesSet_::Rules_::TerminalRulesMapT());
299 rules_maps.push_back(rules_shards->at(i)->terminal_rules.get());
300 rules_maps.push_back(rules_shards->at(i)->lowercase_terminal_rules.get());
301 }
302 for (int i = 0; i < terminal_rules_sets.size(); i++) {
303 for (const auto& it : *terminal_rules_sets[i]) {
304 terminal_rules.push_back(
305 TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
306 }
307 }
308 std::stable_sort(terminal_rules.begin(), terminal_rules.end(),
309 [](const TerminalEntry& a, const TerminalEntry& b) {
310 return a.terminal < b.terminal;
311 });
312
313 // Index the entries in sorted order.
314 std::vector<int> index(terminal_rules_sets.size(), 0);
315 for (int i = 0; i < terminal_rules.size(); i++) {
316 terminal_rules[i].index = index[terminal_rules[i].set_index]++;
317 }
318
319 // We store the terminal strings sorted into a buffer and keep offsets into
320 // that buffer. In this way, we don't need extra space for terminals that are
321 // suffixes of others.
322
323 // Find terminals that are a suffix of others, O(n^2) algorithm.
324 constexpr int kInvalidIndex = -1;
325 std::vector<int> suffix(terminal_rules.size(), kInvalidIndex);
326 for (int i = 0; i < terminal_rules.size(); i++) {
327 const StringPiece terminal(terminal_rules[i].terminal);
328
329 // Check whether the ith terminal is a suffix of another.
330 for (int j = 0; j < terminal_rules.size(); j++) {
331 if (i == j) {
332 continue;
333 }
334 if (StringPiece(terminal_rules[j].terminal).EndsWith(terminal)) {
335 // If both terminals are the same keep the first.
336 // This avoids cyclic dependencies.
337 // This can happen if multiple shards use same terminals, such as
338 // punctuation.
339 if (terminal_rules[j].terminal.size() == terminal.size() && j < i) {
340 continue;
341 }
342 suffix[i] = j;
343 break;
344 }
345 }
346 }
347
348 rules_set->terminals = "";
349
350 for (int i = 0; i < terminal_rules_sets.size(); i++) {
351 rules_maps[i]->terminal_offsets.resize(terminal_rules_sets[i]->size());
352 rules_maps[i]->max_terminal_length = 0;
353 rules_maps[i]->min_terminal_length = std::numeric_limits<int>::max();
354 }
355
356 for (int i = 0; i < terminal_rules.size(); i++) {
357 const TerminalEntry& entry = terminal_rules[i];
358
359 // Update bounds.
360 rules_maps[entry.set_index]->min_terminal_length =
361 std::min(rules_maps[entry.set_index]->min_terminal_length,
362 static_cast<int>(entry.terminal.size()));
363 rules_maps[entry.set_index]->max_terminal_length =
364 std::max(rules_maps[entry.set_index]->max_terminal_length,
365 static_cast<int>(entry.terminal.size()));
366
367 // Only include terminals that are not suffixes of others.
368 if (suffix[i] != kInvalidIndex) {
369 continue;
370 }
371
372 rules_maps[entry.set_index]->terminal_offsets[entry.index] =
373 rules_set->terminals.length();
374 rules_set->terminals += entry.terminal + '\0';
375 }
376
377 // Store just an offset into the existing terminal data for the terminals
378 // that are suffixes of others.
379 for (int i = 0; i < terminal_rules.size(); i++) {
380 int canonical_index = i;
381 if (suffix[canonical_index] == kInvalidIndex) {
382 continue;
383 }
384
385 // Find the overlapping string that was included in the data.
386 while (suffix[canonical_index] != kInvalidIndex) {
387 canonical_index = suffix[canonical_index];
388 }
389
390 const TerminalEntry& entry = terminal_rules[i];
391 const TerminalEntry& canonical_entry = terminal_rules[canonical_index];
392
393 // The offset is the offset of the overlapping string and the offset within
394 // that string.
395 rules_maps[entry.set_index]->terminal_offsets[entry.index] =
396 rules_maps[canonical_entry.set_index]
397 ->terminal_offsets[canonical_entry.index] +
398 (canonical_entry.terminal.length() - entry.terminal.length());
399 }
400
401 for (const TerminalEntry& entry : terminal_rules) {
402 rules_maps[entry.set_index]->lhs_set_index.push_back(
403 AddLhsSet(entry.lhs_set, rules_set));
404 }
405 }
406
Serialize(const bool include_debug_information,RulesSetT * output) const407 void Ir::Serialize(const bool include_debug_information,
408 RulesSetT* output) const {
409 // Add information about predefined nonterminal classes.
410 output->nonterminals.reset(new RulesSet_::NonterminalsT);
411 output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
412 output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
413 output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
414 output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
415 output->nonterminals->uppercase_token_nt =
416 GetNonterminalForName(kUppercaseTokenNonterm);
417 output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
418 for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
419 if (const Nonterm n_digits_nt =
420 GetNonterminalForName(strings::StringPrintf(kNDigitsNonterm, i))) {
421 output->nonterminals->n_digits_nt.resize(i, kUnassignedNonterm);
422 output->nonterminals->n_digits_nt[i - 1] = n_digits_nt;
423 }
424 }
425 for (const auto& [annotation, annotation_nt] : annotations_) {
426 output->nonterminals->annotation_nt.emplace_back(
427 new RulesSet_::Nonterminals_::AnnotationNtEntryT);
428 output->nonterminals->annotation_nt.back()->key = annotation;
429 output->nonterminals->annotation_nt.back()->value = annotation_nt;
430 }
431 SortForBinarySearchLookup(&output->nonterminals->annotation_nt);
432
433 if (include_debug_information) {
434 output->debug_information.reset(new RulesSet_::DebugInformationT);
435 // Keep original non-terminal names.
436 for (const auto& it : nonterminal_names_) {
437 output->debug_information->nonterminal_names.emplace_back(
438 new RulesSet_::DebugInformation_::NonterminalNamesEntryT);
439 output->debug_information->nonterminal_names.back()->key = it.first;
440 output->debug_information->nonterminal_names.back()->value = it.second;
441 }
442 SortForBinarySearchLookup(&output->debug_information->nonterminal_names);
443 }
444
445 // Add regex rules.
446 std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
447 for (auto [pattern, lhs] : regex_rules_) {
448 output->regex_annotator.emplace_back(new RulesSet_::RegexAnnotatorT);
449 output->regex_annotator.back()->compressed_pattern.reset(
450 new CompressedBufferT);
451 compressor->Compress(
452 pattern, output->regex_annotator.back()->compressed_pattern.get());
453 output->regex_annotator.back()->nonterminal = lhs;
454 }
455
456 // Serialize the unary and binary rules.
457 for (int i = 0; i < shards_.size(); i++) {
458 output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
459 RulesSet_::RulesT* rules = output->rules.back().get();
460 for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) {
461 if (shard_locale.IsValid()) {
462 // Check if the language is set to all i.e. '*' which is a special, to
463 // make it consistent with device side parser here instead of filling
464 // the all locale leave the language tag list empty
465 rules->locale.emplace_back(
466 std::make_unique<libtextclassifier3::LanguageTagT>());
467 libtextclassifier3::LanguageTagT* language_tag =
468 rules->locale.back().get();
469 language_tag->language = shard_locale.Language();
470 language_tag->region = shard_locale.Region();
471 language_tag->script = shard_locale.Script();
472 }
473 }
474
475 // Serialize the unary rules.
476 SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules);
477 // Serialize the binary rules.
478 SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules);
479 }
480 // Serialize the terminal rules.
481 // We keep the rules separate by shard but merge the actual terminals into
482 // one shared string pool to most effectively exploit reuse.
483 SerializeTerminalRules(output, &output->rules);
484 }
485
SerializeAsFlatbuffer(const bool include_debug_information) const486 std::string Ir::SerializeAsFlatbuffer(
487 const bool include_debug_information) const {
488 RulesSetT output;
489 Serialize(include_debug_information, &output);
490 flatbuffers::FlatBufferBuilder builder;
491 builder.Finish(RulesSet::Pack(builder, &output));
492 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
493 builder.GetSize());
494 }
495
496 } // namespace libtextclassifier3::grammar
497