xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/parsing/parser.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/parsing/parser.h"
18 
19 #include <algorithm>
20 #include <unordered_map>
21 
22 #include "utils/grammar/parsing/parse-tree.h"
23 #include "utils/grammar/rules-utils.h"
24 #include "utils/grammar/types.h"
25 #include "utils/zlib/zlib.h"
26 #include "utils/zlib/zlib_regex.h"
27 
28 namespace libtextclassifier3::grammar {
29 namespace {
30 
CheckMemoryUsage(const UnsafeArena * arena)31 inline bool CheckMemoryUsage(const UnsafeArena* arena) {
32   // The maximum memory usage for matching.
33   constexpr int kMaxMemoryUsage = 1 << 20;
34   return arena->status().bytes_allocated() <= kMaxMemoryUsage;
35 }
36 
37 // Maps a codepoint to include the token padding if it aligns with a token
38 // start. Whitespace is ignored when symbols are fed to the matcher. Preceding
39 // whitespace is merged to the match start so that tokens and non-terminals
40 // appear next to each other without whitespace. For text or regex annotations,
41 // we therefore merge the whitespace padding to the start if the annotation
42 // starts at a token.
MapCodepointToTokenPaddingIfPresent(const std::unordered_map<CodepointIndex,CodepointIndex> & token_alignment,const int start)43 int MapCodepointToTokenPaddingIfPresent(
44     const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
45     const int start) {
46   const auto it = token_alignment.find(start);
47   if (it != token_alignment.end()) {
48     return it->second;
49   }
50   return start;
51 }
52 
53 }  // namespace
54 
Parser(const UniLib * unilib,const RulesSet * rules)55 Parser::Parser(const UniLib* unilib, const RulesSet* rules)
56     : unilib_(*unilib),
57       rules_(rules),
58       lexer_(unilib),
59       nonterminals_(rules_->nonterminals()),
60       rules_locales_(ParseRulesLocales(rules_)),
61       regex_annotators_(BuildRegexAnnotators()) {}
62 
63 // Uncompresses and build the defined regex annotators.
BuildRegexAnnotators() const64 std::vector<Parser::RegexAnnotator> Parser::BuildRegexAnnotators() const {
65   std::vector<RegexAnnotator> result;
66   if (rules_->regex_annotator() != nullptr) {
67     std::unique_ptr<ZlibDecompressor> decompressor =
68         ZlibDecompressor::Instance();
69     result.reserve(rules_->regex_annotator()->size());
70     for (const RulesSet_::RegexAnnotator* regex_annotator :
71          *rules_->regex_annotator()) {
72       result.push_back(
73           {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
74                                       regex_annotator->compressed_pattern(),
75                                       rules_->lazy_regex_compilation(),
76                                       decompressor.get()),
77            regex_annotator->nonterminal()});
78     }
79   }
80   return result;
81 }
82 
SortedSymbolsForInput(const TextContext & input,UnsafeArena * arena) const83 std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input,
84                                                   UnsafeArena* arena) const {
85   // Whitespace is ignored when symbols are fed to the matcher.
86   // For regex matches and existing text annotations we therefore have to merge
87   // preceding whitespace to the match start so that tokens and non-terminals
88   // appear as next to each other without whitespace. We keep track of real
89   // token starts and precending whitespace in `token_match_start`, so that we
90   // can extend a match's start to include the preceding whitespace.
91   std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
92   for (int i = input.context_span.first + 1; i < input.context_span.second;
93        i++) {
94     const CodepointIndex token_start = input.tokens[i].start;
95     const CodepointIndex prev_token_end = input.tokens[i - 1].end;
96     if (token_start != prev_token_end) {
97       token_match_start[token_start] = prev_token_end;
98     }
99   }
100 
101   std::vector<Symbol> symbols;
102   CodepointIndex match_offset = input.tokens[input.context_span.first].start;
103 
104   // Add start symbol.
105   if (input.context_span.first == 0 &&
106       nonterminals_->start_nt() != kUnassignedNonterm) {
107     match_offset = 0;
108     symbols.emplace_back(arena->AllocAndInit<ParseTree>(
109         nonterminals_->start_nt(), CodepointSpan{0, 0},
110         /*match_offset=*/0, ParseTree::Type::kDefault));
111   }
112 
113   if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
114     symbols.emplace_back(arena->AllocAndInit<ParseTree>(
115         nonterminals_->wordbreak_nt(),
116         CodepointSpan{match_offset, match_offset},
117         /*match_offset=*/match_offset, ParseTree::Type::kDefault));
118   }
119 
120   // Add symbols from tokens.
121   for (int i = input.context_span.first; i < input.context_span.second; i++) {
122     const Token& token = input.tokens[i];
123     lexer_.AppendTokenSymbols(token.value, /*match_offset=*/match_offset,
124                               CodepointSpan{token.start, token.end}, &symbols);
125     match_offset = token.end;
126 
127     // Add word break symbol.
128     if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
129       symbols.emplace_back(arena->AllocAndInit<ParseTree>(
130           nonterminals_->wordbreak_nt(),
131           CodepointSpan{match_offset, match_offset},
132           /*match_offset=*/match_offset, ParseTree::Type::kDefault));
133     }
134   }
135 
136   // Add end symbol if used by the grammar.
137   if (input.context_span.second == input.tokens.size() &&
138       nonterminals_->end_nt() != kUnassignedNonterm) {
139     symbols.emplace_back(arena->AllocAndInit<ParseTree>(
140         nonterminals_->end_nt(), CodepointSpan{match_offset, match_offset},
141         /*match_offset=*/match_offset, ParseTree::Type::kDefault));
142   }
143 
144   // Add symbols from the regex annotators.
145   const CodepointIndex context_start =
146       input.tokens[input.context_span.first].start;
147   const CodepointIndex context_end =
148       input.tokens[input.context_span.second - 1].end;
149   for (const RegexAnnotator& regex_annotator : regex_annotators_) {
150     std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
151         regex_annotator.pattern->Matcher(UnicodeText::Substring(
152             input.text, context_start, context_end, /*do_copy=*/false));
153     int status = UniLib::RegexMatcher::kNoError;
154     while (regex_matcher->Find(&status) &&
155            status == UniLib::RegexMatcher::kNoError) {
156       const CodepointSpan span{regex_matcher->Start(0, &status) + context_start,
157                                regex_matcher->End(0, &status) + context_start};
158       symbols.emplace_back(arena->AllocAndInit<ParseTree>(
159           regex_annotator.nonterm, span, /*match_offset=*/
160           MapCodepointToTokenPaddingIfPresent(token_match_start, span.first),
161           ParseTree::Type::kDefault));
162     }
163   }
164 
165   // Add symbols based on annotations.
166   if (auto annotation_nonterminals = nonterminals_->annotation_nt()) {
167     for (const AnnotatedSpan& annotated_span : input.annotations) {
168       const ClassificationResult& classification =
169           annotated_span.classification.front();
170       if (auto entry = annotation_nonterminals->LookupByKey(
171               classification.collection.c_str())) {
172         symbols.emplace_back(arena->AllocAndInit<AnnotationNode>(
173             entry->value(), annotated_span.span, /*match_offset=*/
174             MapCodepointToTokenPaddingIfPresent(token_match_start,
175                                                 annotated_span.span.first),
176             &classification));
177       }
178     }
179   }
180 
181   std::stable_sort(
182       symbols.begin(), symbols.end(), [](const Symbol& a, const Symbol& b) {
183         // Sort by increasing (end, start) position to guarantee the
184         // matcher requirement that the tokens are fed in non-decreasing
185         // end position order.
186         return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
187                std::tie(b.codepoint_span.second, b.codepoint_span.first);
188       });
189 
190   return symbols;
191 }
192 
EmitSymbol(const Symbol & symbol,UnsafeArena * arena,Matcher * matcher) const193 void Parser::EmitSymbol(const Symbol& symbol, UnsafeArena* arena,
194                         Matcher* matcher) const {
195   if (!CheckMemoryUsage(arena)) {
196     return;
197   }
198   switch (symbol.type) {
199     case Symbol::Type::TYPE_PARSE_TREE: {
200       // Just emit the parse tree.
201       matcher->AddParseTree(symbol.parse_tree);
202       return;
203     }
204     case Symbol::Type::TYPE_DIGITS: {
205       // Emit <digits> if used by the rules.
206       if (nonterminals_->digits_nt() != kUnassignedNonterm) {
207         matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
208             nonterminals_->digits_nt(), symbol.codepoint_span,
209             symbol.match_offset, ParseTree::Type::kDefault));
210       }
211 
212       // Emit <n_digits> if used by the rules.
213       if (nonterminals_->n_digits_nt() != nullptr) {
214         const int num_digits =
215             symbol.codepoint_span.second - symbol.codepoint_span.first;
216         if (num_digits <= nonterminals_->n_digits_nt()->size()) {
217           const Nonterm n_digits_nt =
218               nonterminals_->n_digits_nt()->Get(num_digits - 1);
219           if (n_digits_nt != kUnassignedNonterm) {
220             matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
221                 nonterminals_->n_digits_nt()->Get(num_digits - 1),
222                 symbol.codepoint_span, symbol.match_offset,
223                 ParseTree::Type::kDefault));
224           }
225         }
226       }
227       break;
228     }
229     case Symbol::Type::TYPE_TERM: {
230       // Emit <uppercase_token> if used by the rules.
231       if (nonterminals_->uppercase_token_nt() != 0 &&
232           unilib_.IsUpperText(
233               UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
234         matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
235             nonterminals_->uppercase_token_nt(), symbol.codepoint_span,
236             symbol.match_offset, ParseTree::Type::kDefault));
237       }
238       break;
239     }
240     default:
241       break;
242   }
243 
244   // Emit the token as terminal.
245   matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
246                        symbol.lexeme);
247 
248   // Emit <token> if used by rules.
249   matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
250       nonterminals_->token_nt(), symbol.codepoint_span, symbol.match_offset,
251       ParseTree::Type::kDefault));
252 }
253 
254 // Parses an input text and returns the root rule derivations.
Parse(const TextContext & input,UnsafeArena * arena) const255 std::vector<Derivation> Parser::Parse(const TextContext& input,
256                                       UnsafeArena* arena) const {
257   // Check the tokens, input can be non-empty (whitespace) but have no tokens.
258   if (input.tokens.empty()) {
259     return {};
260   }
261 
262   // Select locale matching rules.
263   std::vector<const RulesSet_::Rules*> locale_rules =
264       SelectLocaleMatchingShards(rules_, rules_locales_, input.locales);
265 
266   if (locale_rules.empty()) {
267     // Nothing to do.
268     return {};
269   }
270 
271   Matcher matcher(&unilib_, rules_, locale_rules, arena);
272   for (const Symbol& symbol : SortedSymbolsForInput(input, arena)) {
273     EmitSymbol(symbol, arena, &matcher);
274   }
275   matcher.Finish();
276   return matcher.chart().derivations();
277 }
278 
279 }  // namespace libtextclassifier3::grammar
280