xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/utils/ir.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
19*993b0882SAndroid Build Coastguard Worker 
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker #include "utils/base/integral_types.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/rules_generated.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/types.h"
28*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/utils/locale-shard-map.h"
29*993b0882SAndroid Build Coastguard Worker 
30*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3::grammar {
31*993b0882SAndroid Build Coastguard Worker 
// Pre-defined nonterminal classes that the lexer can handle.
//
// These are `inline` variables so that every translation unit including this
// header refers to one shared definition instead of a private per-file copy.

// Names of the start, end and word-break nonterminals.
inline constexpr const char* kStartNonterm = "<^>";
inline constexpr const char* kEndNonterm = "<$>";
// Note: `\b` is the C++ backspace escape, i.e. this name is three bytes long.
inline constexpr const char* kWordBreakNonterm = "<\b>";

// Names of the token-level nonterminals.
inline constexpr const char* kTokenNonterm = "<token>";
inline constexpr const char* kUppercaseTokenNonterm = "<uppercase_token>";
inline constexpr const char* kDigitsNonterm = "<digits>";

// printf-style pattern for a digits nonterminal name.
// NOTE(review): `%d` is presumably substituted with the number of digits, and
// kMaxNDigitsNontermLength presumably bounds the length of the formatted
// name -- confirm against the code that formats it.
inline constexpr const char* kNDigitsNonterm = "<%d_digits>";
inline constexpr int kMaxNDigitsNontermLength = 20;
41*993b0882SAndroid Build Coastguard Worker 
42*993b0882SAndroid Build Coastguard Worker // Low-level intermediate rules representation.
43*993b0882SAndroid Build Coastguard Worker // In this representation, nonterminals are specified simply as integers
44*993b0882SAndroid Build Coastguard Worker // (Nonterms), rather than strings which is more efficient.
45*993b0882SAndroid Build Coastguard Worker // Rule set optimizations are done on this representation.
46*993b0882SAndroid Build Coastguard Worker //
47*993b0882SAndroid Build Coastguard Worker // Rules are represented in (mostly) Chomsky Normal Form, where all rules are
48*993b0882SAndroid Build Coastguard Worker // of the following form, either:
49*993b0882SAndroid Build Coastguard Worker //   * <nonterm> ::= term
50*993b0882SAndroid Build Coastguard Worker //   * <nonterm> ::= <nonterm>
51*993b0882SAndroid Build Coastguard Worker //   * <nonterm> ::= <nonterm> <nonterm>
52*993b0882SAndroid Build Coastguard Worker class Ir {
53*993b0882SAndroid Build Coastguard Worker  public:
54*993b0882SAndroid Build Coastguard Worker   // A rule callback as a callback id and parameter pair.
55*993b0882SAndroid Build Coastguard Worker   struct Callback {
56*993b0882SAndroid Build Coastguard Worker     bool operator==(const Callback& other) const {
57*993b0882SAndroid Build Coastguard Worker       return std::tie(id, param) == std::tie(other.id, other.param);
58*993b0882SAndroid Build Coastguard Worker     }
59*993b0882SAndroid Build Coastguard Worker 
60*993b0882SAndroid Build Coastguard Worker     CallbackId id = kNoCallback;
61*993b0882SAndroid Build Coastguard Worker     int64 param = 0;
62*993b0882SAndroid Build Coastguard Worker   };
63*993b0882SAndroid Build Coastguard Worker 
64*993b0882SAndroid Build Coastguard Worker   // Constraints for triggering a rule.
65*993b0882SAndroid Build Coastguard Worker   struct Preconditions {
66*993b0882SAndroid Build Coastguard Worker     bool operator==(const Preconditions& other) const {
67*993b0882SAndroid Build Coastguard Worker       return max_whitespace_gap == other.max_whitespace_gap;
68*993b0882SAndroid Build Coastguard Worker     }
69*993b0882SAndroid Build Coastguard Worker 
70*993b0882SAndroid Build Coastguard Worker     // The maximum allowed whitespace between parts of the rule.
71*993b0882SAndroid Build Coastguard Worker     // The default of -1 allows for unbounded whitespace.
72*993b0882SAndroid Build Coastguard Worker     int8 max_whitespace_gap = -1;
73*993b0882SAndroid Build Coastguard Worker   };
74*993b0882SAndroid Build Coastguard Worker 
75*993b0882SAndroid Build Coastguard Worker   struct Lhs {
76*993b0882SAndroid Build Coastguard Worker     bool operator==(const Lhs& other) const {
77*993b0882SAndroid Build Coastguard Worker       return std::tie(nonterminal, callback, preconditions) ==
78*993b0882SAndroid Build Coastguard Worker              std::tie(other.nonterminal, other.callback, other.preconditions);
79*993b0882SAndroid Build Coastguard Worker     }
80*993b0882SAndroid Build Coastguard Worker 
81*993b0882SAndroid Build Coastguard Worker     Nonterm nonterminal = kUnassignedNonterm;
82*993b0882SAndroid Build Coastguard Worker     Callback callback;
83*993b0882SAndroid Build Coastguard Worker     Preconditions preconditions;
84*993b0882SAndroid Build Coastguard Worker   };
85*993b0882SAndroid Build Coastguard Worker   using LhsSet = std::vector<Lhs>;
86*993b0882SAndroid Build Coastguard Worker 
87*993b0882SAndroid Build Coastguard Worker   // A rules shard.
88*993b0882SAndroid Build Coastguard Worker   struct RulesShard {
89*993b0882SAndroid Build Coastguard Worker     // Terminal rules.
90*993b0882SAndroid Build Coastguard Worker     std::unordered_map<std::string, LhsSet> terminal_rules;
91*993b0882SAndroid Build Coastguard Worker     std::unordered_map<std::string, LhsSet> lowercase_terminal_rules;
92*993b0882SAndroid Build Coastguard Worker 
93*993b0882SAndroid Build Coastguard Worker     // Unary rules.
94*993b0882SAndroid Build Coastguard Worker     std::unordered_map<Nonterm, LhsSet> unary_rules;
95*993b0882SAndroid Build Coastguard Worker 
96*993b0882SAndroid Build Coastguard Worker     // Binary rules.
97*993b0882SAndroid Build Coastguard Worker     std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules;
98*993b0882SAndroid Build Coastguard Worker   };
99*993b0882SAndroid Build Coastguard Worker 
Ir(const LocaleShardMap & locale_shard_map)100*993b0882SAndroid Build Coastguard Worker   explicit Ir(const LocaleShardMap& locale_shard_map)
101*993b0882SAndroid Build Coastguard Worker       : num_nonterminals_(0),
102*993b0882SAndroid Build Coastguard Worker         locale_shard_map_(locale_shard_map),
103*993b0882SAndroid Build Coastguard Worker         shards_(locale_shard_map_.GetNumberOfShards()) {}
104*993b0882SAndroid Build Coastguard Worker 
105*993b0882SAndroid Build Coastguard Worker   // Adds a new non-terminal.
106*993b0882SAndroid Build Coastguard Worker   Nonterm AddNonterminal(const std::string& name = "") {
107*993b0882SAndroid Build Coastguard Worker     const Nonterm nonterminal = ++num_nonterminals_;
108*993b0882SAndroid Build Coastguard Worker     if (!name.empty()) {
109*993b0882SAndroid Build Coastguard Worker       // Record debug information.
110*993b0882SAndroid Build Coastguard Worker       SetNonterminal(name, nonterminal);
111*993b0882SAndroid Build Coastguard Worker     }
112*993b0882SAndroid Build Coastguard Worker     return nonterminal;
113*993b0882SAndroid Build Coastguard Worker   }
114*993b0882SAndroid Build Coastguard Worker 
115*993b0882SAndroid Build Coastguard Worker   // Sets the name of a nonterminal.
SetNonterminal(const std::string & name,const Nonterm nonterminal)116*993b0882SAndroid Build Coastguard Worker   void SetNonterminal(const std::string& name, const Nonterm nonterminal) {
117*993b0882SAndroid Build Coastguard Worker     nonterminal_names_[nonterminal] = name;
118*993b0882SAndroid Build Coastguard Worker     nonterminal_ids_[name] = nonterminal;
119*993b0882SAndroid Build Coastguard Worker   }
120*993b0882SAndroid Build Coastguard Worker 
121*993b0882SAndroid Build Coastguard Worker   // Defines a nonterminal if not yet defined.
DefineNonterminal(Nonterm nonterminal)122*993b0882SAndroid Build Coastguard Worker   Nonterm DefineNonterminal(Nonterm nonterminal) {
123*993b0882SAndroid Build Coastguard Worker     return (nonterminal != kUnassignedNonterm) ? nonterminal : AddNonterminal();
124*993b0882SAndroid Build Coastguard Worker   }
125*993b0882SAndroid Build Coastguard Worker 
126*993b0882SAndroid Build Coastguard Worker   // Defines a new non-terminal that cannot be shared internally.
127*993b0882SAndroid Build Coastguard Worker   Nonterm AddUnshareableNonterminal(const std::string& name = "") {
128*993b0882SAndroid Build Coastguard Worker     const Nonterm nonterminal = AddNonterminal(name);
129*993b0882SAndroid Build Coastguard Worker     nonshareable_.insert(nonterminal);
130*993b0882SAndroid Build Coastguard Worker     return nonterminal;
131*993b0882SAndroid Build Coastguard Worker   }
132*993b0882SAndroid Build Coastguard Worker 
133*993b0882SAndroid Build Coastguard Worker   // Gets the non-terminal for a given name, if it was previously defined.
GetNonterminalForName(const std::string & name)134*993b0882SAndroid Build Coastguard Worker   Nonterm GetNonterminalForName(const std::string& name) const {
135*993b0882SAndroid Build Coastguard Worker     const auto it = nonterminal_ids_.find(name);
136*993b0882SAndroid Build Coastguard Worker     if (it == nonterminal_ids_.end()) {
137*993b0882SAndroid Build Coastguard Worker       return kUnassignedNonterm;
138*993b0882SAndroid Build Coastguard Worker     }
139*993b0882SAndroid Build Coastguard Worker     return it->second;
140*993b0882SAndroid Build Coastguard Worker   }
141*993b0882SAndroid Build Coastguard Worker 
142*993b0882SAndroid Build Coastguard Worker   // Adds a terminal rule <lhs> ::= terminal.
143*993b0882SAndroid Build Coastguard Worker   Nonterm Add(const Lhs& lhs, const std::string& terminal,
144*993b0882SAndroid Build Coastguard Worker               bool case_sensitive = false, int shard = 0);
145*993b0882SAndroid Build Coastguard Worker   Nonterm Add(const Nonterm lhs, const std::string& terminal,
146*993b0882SAndroid Build Coastguard Worker               bool case_sensitive = false, int shard = 0) {
147*993b0882SAndroid Build Coastguard Worker     return Add(Lhs{lhs}, terminal, case_sensitive, shard);
148*993b0882SAndroid Build Coastguard Worker   }
149*993b0882SAndroid Build Coastguard Worker 
150*993b0882SAndroid Build Coastguard Worker   // Adds a unary rule <lhs> ::= <rhs>.
151*993b0882SAndroid Build Coastguard Worker   Nonterm Add(const Lhs& lhs, Nonterm rhs, int shard = 0) {
152*993b0882SAndroid Build Coastguard Worker     return AddRule(lhs, rhs, &shards_[shard].unary_rules);
153*993b0882SAndroid Build Coastguard Worker   }
154*993b0882SAndroid Build Coastguard Worker   Nonterm Add(Nonterm lhs, Nonterm rhs, int shard = 0) {
155*993b0882SAndroid Build Coastguard Worker     return Add(Lhs{lhs}, rhs, shard);
156*993b0882SAndroid Build Coastguard Worker   }
157*993b0882SAndroid Build Coastguard Worker 
158*993b0882SAndroid Build Coastguard Worker   // Adds a binary rule <lhs> ::= <rhs_1> <rhs_2>.
159*993b0882SAndroid Build Coastguard Worker   Nonterm Add(const Lhs& lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) {
160*993b0882SAndroid Build Coastguard Worker     return AddRule(lhs, {rhs_1, rhs_2}, &shards_[shard].binary_rules);
161*993b0882SAndroid Build Coastguard Worker   }
162*993b0882SAndroid Build Coastguard Worker   Nonterm Add(Nonterm lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) {
163*993b0882SAndroid Build Coastguard Worker     return Add(Lhs{lhs}, rhs_1, rhs_2, shard);
164*993b0882SAndroid Build Coastguard Worker   }
165*993b0882SAndroid Build Coastguard Worker 
166*993b0882SAndroid Build Coastguard Worker   // Adds a rule <lhs> ::= <rhs_1> <rhs_2> ... <rhs_k>
167*993b0882SAndroid Build Coastguard Worker   //
168*993b0882SAndroid Build Coastguard Worker   // If k > 2, we internally create a series of Nonterms representing prefixes
169*993b0882SAndroid Build Coastguard Worker   // of the full rhs.
170*993b0882SAndroid Build Coastguard Worker   //     <temp_1> ::= <RHS_1> <RHS_2>
171*993b0882SAndroid Build Coastguard Worker   //     <temp_2> ::= <temp_1> <RHS_3>
172*993b0882SAndroid Build Coastguard Worker   //     ...
173*993b0882SAndroid Build Coastguard Worker   //     <LHS> ::= <temp_(k-1)> <RHS_k>
174*993b0882SAndroid Build Coastguard Worker   Nonterm Add(const Lhs& lhs, const std::vector<Nonterm>& rhs, int shard = 0);
175*993b0882SAndroid Build Coastguard Worker   Nonterm Add(Nonterm lhs, const std::vector<Nonterm>& rhs, int shard = 0) {
176*993b0882SAndroid Build Coastguard Worker     return Add(Lhs{lhs}, rhs, shard);
177*993b0882SAndroid Build Coastguard Worker   }
178*993b0882SAndroid Build Coastguard Worker 
179*993b0882SAndroid Build Coastguard Worker   // Adds a regex rule <lhs> ::= <regex_pattern>.
180*993b0882SAndroid Build Coastguard Worker   Nonterm AddRegex(Nonterm lhs, const std::string& regex_pattern);
181*993b0882SAndroid Build Coastguard Worker 
182*993b0882SAndroid Build Coastguard Worker   // Adds a definition for a nonterminal provided by a text annotation.
183*993b0882SAndroid Build Coastguard Worker   void AddAnnotation(Nonterm lhs, const std::string& annotation);
184*993b0882SAndroid Build Coastguard Worker 
185*993b0882SAndroid Build Coastguard Worker   // Serializes a rule set in the intermediate representation into the
186*993b0882SAndroid Build Coastguard Worker   // memory mappable inference format.
187*993b0882SAndroid Build Coastguard Worker   void Serialize(bool include_debug_information, RulesSetT* output) const;
188*993b0882SAndroid Build Coastguard Worker 
189*993b0882SAndroid Build Coastguard Worker   std::string SerializeAsFlatbuffer(
190*993b0882SAndroid Build Coastguard Worker       bool include_debug_information = false) const;
191*993b0882SAndroid Build Coastguard Worker 
shards()192*993b0882SAndroid Build Coastguard Worker   const std::vector<RulesShard>& shards() const { return shards_; }
regex_rules()193*993b0882SAndroid Build Coastguard Worker   const std::vector<std::pair<std::string, Nonterm>>& regex_rules() const {
194*993b0882SAndroid Build Coastguard Worker     return regex_rules_;
195*993b0882SAndroid Build Coastguard Worker   }
annotations()196*993b0882SAndroid Build Coastguard Worker   const std::vector<std::pair<std::string, Nonterm>>& annotations() const {
197*993b0882SAndroid Build Coastguard Worker     return annotations_;
198*993b0882SAndroid Build Coastguard Worker   }
199*993b0882SAndroid Build Coastguard Worker 
200*993b0882SAndroid Build Coastguard Worker  private:
201*993b0882SAndroid Build Coastguard Worker   template <typename R, typename H>
AddRule(const Lhs & lhs,const R & rhs,std::unordered_map<R,LhsSet,H> * rules)202*993b0882SAndroid Build Coastguard Worker   Nonterm AddRule(const Lhs& lhs, const R& rhs,
203*993b0882SAndroid Build Coastguard Worker                   std::unordered_map<R, LhsSet, H>* rules) {
204*993b0882SAndroid Build Coastguard Worker     const auto it = rules->find(rhs);
205*993b0882SAndroid Build Coastguard Worker 
206*993b0882SAndroid Build Coastguard Worker     // Rhs was not yet used.
207*993b0882SAndroid Build Coastguard Worker     if (it == rules->end()) {
208*993b0882SAndroid Build Coastguard Worker       const Nonterm nonterminal = DefineNonterminal(lhs.nonterminal);
209*993b0882SAndroid Build Coastguard Worker       rules->insert(it,
210*993b0882SAndroid Build Coastguard Worker                     {rhs, {Lhs{nonterminal, lhs.callback, lhs.preconditions}}});
211*993b0882SAndroid Build Coastguard Worker       return nonterminal;
212*993b0882SAndroid Build Coastguard Worker     }
213*993b0882SAndroid Build Coastguard Worker 
214*993b0882SAndroid Build Coastguard Worker     return AddToSet(lhs, &it->second);
215*993b0882SAndroid Build Coastguard Worker   }
216*993b0882SAndroid Build Coastguard Worker 
217*993b0882SAndroid Build Coastguard Worker   // Adds a new callback to an lhs set, potentially sharing nonterminal ids and
218*993b0882SAndroid Build Coastguard Worker   // existing callbacks.
219*993b0882SAndroid Build Coastguard Worker   Nonterm AddToSet(const Lhs& lhs, LhsSet* lhs_set);
220*993b0882SAndroid Build Coastguard Worker 
221*993b0882SAndroid Build Coastguard Worker   // Serializes the sharded terminal rules.
222*993b0882SAndroid Build Coastguard Worker   void SerializeTerminalRules(
223*993b0882SAndroid Build Coastguard Worker       RulesSetT* rules_set,
224*993b0882SAndroid Build Coastguard Worker       std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const;
225*993b0882SAndroid Build Coastguard Worker 
226*993b0882SAndroid Build Coastguard Worker   // The defined non-terminals.
227*993b0882SAndroid Build Coastguard Worker   Nonterm num_nonterminals_;
228*993b0882SAndroid Build Coastguard Worker   std::unordered_set<Nonterm> nonshareable_;
229*993b0882SAndroid Build Coastguard Worker 
230*993b0882SAndroid Build Coastguard Worker   // Locale information for Rules
231*993b0882SAndroid Build Coastguard Worker   const LocaleShardMap& locale_shard_map_;
232*993b0882SAndroid Build Coastguard Worker   // The sharded rules.
233*993b0882SAndroid Build Coastguard Worker   std::vector<RulesShard> shards_;
234*993b0882SAndroid Build Coastguard Worker 
235*993b0882SAndroid Build Coastguard Worker   // The regex rules.
236*993b0882SAndroid Build Coastguard Worker   std::vector<std::pair<std::string, Nonterm>> regex_rules_;
237*993b0882SAndroid Build Coastguard Worker 
238*993b0882SAndroid Build Coastguard Worker   // Mapping from annotation name to nonterminal.
239*993b0882SAndroid Build Coastguard Worker   std::vector<std::pair<std::string, Nonterm>> annotations_;
240*993b0882SAndroid Build Coastguard Worker 
241*993b0882SAndroid Build Coastguard Worker   // Debug information.
242*993b0882SAndroid Build Coastguard Worker   std::unordered_map<Nonterm, std::string> nonterminal_names_;
243*993b0882SAndroid Build Coastguard Worker   std::unordered_map<std::string, Nonterm> nonterminal_ids_;
244*993b0882SAndroid Build Coastguard Worker };
245*993b0882SAndroid Build Coastguard Worker 
246*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3::grammar
247*993b0882SAndroid Build Coastguard Worker 
248*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
249