xref: /aosp_15_r20/external/libtextclassifier/native/annotator/annotator.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "annotator/annotator.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <algorithm>
20*993b0882SAndroid Build Coastguard Worker #include <cmath>
21*993b0882SAndroid Build Coastguard Worker #include <cstddef>
22*993b0882SAndroid Build Coastguard Worker #include <iterator>
23*993b0882SAndroid Build Coastguard Worker #include <limits>
24*993b0882SAndroid Build Coastguard Worker #include <numeric>
25*993b0882SAndroid Build Coastguard Worker #include <string>
26*993b0882SAndroid Build Coastguard Worker #include <unordered_map>
27*993b0882SAndroid Build Coastguard Worker #include <vector>
28*993b0882SAndroid Build Coastguard Worker 
29*993b0882SAndroid Build Coastguard Worker #include "annotator/collections.h"
30*993b0882SAndroid Build Coastguard Worker #include "annotator/datetime/grammar-parser.h"
31*993b0882SAndroid Build Coastguard Worker #include "annotator/datetime/regex-parser.h"
32*993b0882SAndroid Build Coastguard Worker #include "annotator/flatbuffer-utils.h"
33*993b0882SAndroid Build Coastguard Worker #include "annotator/knowledge/knowledge-engine-types.h"
34*993b0882SAndroid Build Coastguard Worker #include "annotator/model_generated.h"
35*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
36*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
37*993b0882SAndroid Build Coastguard Worker #include "utils/base/status.h"
38*993b0882SAndroid Build Coastguard Worker #include "utils/base/statusor.h"
39*993b0882SAndroid Build Coastguard Worker #include "utils/calendar/calendar.h"
40*993b0882SAndroid Build Coastguard Worker #include "utils/checksum.h"
41*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/analyzer.h"
42*993b0882SAndroid Build Coastguard Worker #include "utils/i18n/locale-list.h"
43*993b0882SAndroid Build Coastguard Worker #include "utils/i18n/locale.h"
44*993b0882SAndroid Build Coastguard Worker #include "utils/math/softmax.h"
45*993b0882SAndroid Build Coastguard Worker #include "utils/normalization.h"
46*993b0882SAndroid Build Coastguard Worker #include "utils/optional.h"
47*993b0882SAndroid Build Coastguard Worker #include "utils/regex-match.h"
48*993b0882SAndroid Build Coastguard Worker #include "utils/strings/append.h"
49*993b0882SAndroid Build Coastguard Worker #include "utils/strings/numbers.h"
50*993b0882SAndroid Build Coastguard Worker #include "utils/strings/split.h"
51*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
52*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unilib-common.h"
53*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib_regex.h"
54*993b0882SAndroid Build Coastguard Worker 
55*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
56*993b0882SAndroid Build Coastguard Worker 
57*993b0882SAndroid Build Coastguard Worker using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58*993b0882SAndroid Build Coastguard Worker 
59*993b0882SAndroid Build Coastguard Worker const std::string& Annotator::kPhoneCollection =
__anon35a9cf7c0102() 60*993b0882SAndroid Build Coastguard Worker     *[]() { return new std::string("phone"); }();
61*993b0882SAndroid Build Coastguard Worker const std::string& Annotator::kAddressCollection =
__anon35a9cf7c0202() 62*993b0882SAndroid Build Coastguard Worker     *[]() { return new std::string("address"); }();
63*993b0882SAndroid Build Coastguard Worker const std::string& Annotator::kDateCollection =
__anon35a9cf7c0302() 64*993b0882SAndroid Build Coastguard Worker     *[]() { return new std::string("date"); }();
65*993b0882SAndroid Build Coastguard Worker const std::string& Annotator::kUrlCollection =
__anon35a9cf7c0402() 66*993b0882SAndroid Build Coastguard Worker     *[]() { return new std::string("url"); }();
67*993b0882SAndroid Build Coastguard Worker const std::string& Annotator::kEmailCollection =
__anon35a9cf7c0502() 68*993b0882SAndroid Build Coastguard Worker     *[]() { return new std::string("email"); }();
69*993b0882SAndroid Build Coastguard Worker 
70*993b0882SAndroid Build Coastguard Worker namespace {
LoadAndVerifyModel(const void * addr,int size)71*993b0882SAndroid Build Coastguard Worker const Model* LoadAndVerifyModel(const void* addr, int size) {
72*993b0882SAndroid Build Coastguard Worker   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73*993b0882SAndroid Build Coastguard Worker   if (VerifyModelBuffer(verifier)) {
74*993b0882SAndroid Build Coastguard Worker     return GetModel(addr);
75*993b0882SAndroid Build Coastguard Worker   } else {
76*993b0882SAndroid Build Coastguard Worker     return nullptr;
77*993b0882SAndroid Build Coastguard Worker   }
78*993b0882SAndroid Build Coastguard Worker }
79*993b0882SAndroid Build Coastguard Worker 
LoadAndVerifyPersonNameModel(const void * addr,int size)80*993b0882SAndroid Build Coastguard Worker const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81*993b0882SAndroid Build Coastguard Worker                                                     int size) {
82*993b0882SAndroid Build Coastguard Worker   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83*993b0882SAndroid Build Coastguard Worker   if (VerifyPersonNameModelBuffer(verifier)) {
84*993b0882SAndroid Build Coastguard Worker     return GetPersonNameModel(addr);
85*993b0882SAndroid Build Coastguard Worker   } else {
86*993b0882SAndroid Build Coastguard Worker     return nullptr;
87*993b0882SAndroid Build Coastguard Worker   }
88*993b0882SAndroid Build Coastguard Worker }
89*993b0882SAndroid Build Coastguard Worker 
90*993b0882SAndroid Build Coastguard Worker // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91*993b0882SAndroid Build Coastguard Worker // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92*993b0882SAndroid Build Coastguard Worker const UniLib* MaybeCreateUnilib(const UniLib* lib,
93*993b0882SAndroid Build Coastguard Worker                                 std::unique_ptr<UniLib>* owned_lib) {
94*993b0882SAndroid Build Coastguard Worker   if (lib) {
95*993b0882SAndroid Build Coastguard Worker     return lib;
96*993b0882SAndroid Build Coastguard Worker   } else {
97*993b0882SAndroid Build Coastguard Worker     owned_lib->reset(new UniLib);
98*993b0882SAndroid Build Coastguard Worker     return owned_lib->get();
99*993b0882SAndroid Build Coastguard Worker   }
100*993b0882SAndroid Build Coastguard Worker }
101*993b0882SAndroid Build Coastguard Worker 
102*993b0882SAndroid Build Coastguard Worker // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103*993b0882SAndroid Build Coastguard Worker const CalendarLib* MaybeCreateCalendarlib(
104*993b0882SAndroid Build Coastguard Worker     const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105*993b0882SAndroid Build Coastguard Worker   if (lib) {
106*993b0882SAndroid Build Coastguard Worker     return lib;
107*993b0882SAndroid Build Coastguard Worker   } else {
108*993b0882SAndroid Build Coastguard Worker     owned_lib->reset(new CalendarLib);
109*993b0882SAndroid Build Coastguard Worker     return owned_lib->get();
110*993b0882SAndroid Build Coastguard Worker   }
111*993b0882SAndroid Build Coastguard Worker }
112*993b0882SAndroid Build Coastguard Worker 
113*993b0882SAndroid Build Coastguard Worker // Returns whether the provided input is valid:
114*993b0882SAndroid Build Coastguard Worker //   * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115*993b0882SAndroid Build Coastguard Worker bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116*993b0882SAndroid Build Coastguard Worker   return (span.first >= 0 && span.first < span.second &&
117*993b0882SAndroid Build Coastguard Worker           span.second <= context.size_codepoints());
118*993b0882SAndroid Build Coastguard Worker }
119*993b0882SAndroid Build Coastguard Worker 
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120*993b0882SAndroid Build Coastguard Worker std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121*993b0882SAndroid Build Coastguard Worker     const flatbuffers::Vector<int32_t>* ints) {
122*993b0882SAndroid Build Coastguard Worker   if (ints == nullptr) {
123*993b0882SAndroid Build Coastguard Worker     return {};
124*993b0882SAndroid Build Coastguard Worker   }
125*993b0882SAndroid Build Coastguard Worker   std::unordered_set<char32> ints_set;
126*993b0882SAndroid Build Coastguard Worker   for (auto value : *ints) {
127*993b0882SAndroid Build Coastguard Worker     ints_set.insert(static_cast<char32>(value));
128*993b0882SAndroid Build Coastguard Worker   }
129*993b0882SAndroid Build Coastguard Worker   return ints_set;
130*993b0882SAndroid Build Coastguard Worker }
131*993b0882SAndroid Build Coastguard Worker 
132*993b0882SAndroid Build Coastguard Worker }  // namespace
133*993b0882SAndroid Build Coastguard Worker 
SelectionInterpreter()134*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135*993b0882SAndroid Build Coastguard Worker   if (!selection_interpreter_) {
136*993b0882SAndroid Build Coastguard Worker     TC3_CHECK(selection_executor_);
137*993b0882SAndroid Build Coastguard Worker     selection_interpreter_ = selection_executor_->CreateInterpreter();
138*993b0882SAndroid Build Coastguard Worker     if (!selection_interpreter_) {
139*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140*993b0882SAndroid Build Coastguard Worker     }
141*993b0882SAndroid Build Coastguard Worker   }
142*993b0882SAndroid Build Coastguard Worker   return selection_interpreter_.get();
143*993b0882SAndroid Build Coastguard Worker }
144*993b0882SAndroid Build Coastguard Worker 
ClassificationInterpreter()145*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146*993b0882SAndroid Build Coastguard Worker   if (!classification_interpreter_) {
147*993b0882SAndroid Build Coastguard Worker     TC3_CHECK(classification_executor_);
148*993b0882SAndroid Build Coastguard Worker     classification_interpreter_ = classification_executor_->CreateInterpreter();
149*993b0882SAndroid Build Coastguard Worker     if (!classification_interpreter_) {
150*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151*993b0882SAndroid Build Coastguard Worker     }
152*993b0882SAndroid Build Coastguard Worker   }
153*993b0882SAndroid Build Coastguard Worker   return classification_interpreter_.get();
154*993b0882SAndroid Build Coastguard Worker }
155*993b0882SAndroid Build Coastguard Worker 
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157*993b0882SAndroid Build Coastguard Worker     const char* buffer, int size, const UniLib* unilib,
158*993b0882SAndroid Build Coastguard Worker     const CalendarLib* calendarlib) {
159*993b0882SAndroid Build Coastguard Worker   const Model* model = LoadAndVerifyModel(buffer, size);
160*993b0882SAndroid Build Coastguard Worker   if (model == nullptr) {
161*993b0882SAndroid Build Coastguard Worker     return nullptr;
162*993b0882SAndroid Build Coastguard Worker   }
163*993b0882SAndroid Build Coastguard Worker 
164*993b0882SAndroid Build Coastguard Worker   auto classifier = std::unique_ptr<Annotator>(new Annotator());
165*993b0882SAndroid Build Coastguard Worker   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166*993b0882SAndroid Build Coastguard Worker   calendarlib =
167*993b0882SAndroid Build Coastguard Worker       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168*993b0882SAndroid Build Coastguard Worker   classifier->ValidateAndInitialize(model, unilib, calendarlib);
169*993b0882SAndroid Build Coastguard Worker   if (!classifier->IsInitialized()) {
170*993b0882SAndroid Build Coastguard Worker     return nullptr;
171*993b0882SAndroid Build Coastguard Worker   }
172*993b0882SAndroid Build Coastguard Worker 
173*993b0882SAndroid Build Coastguard Worker   return classifier;
174*993b0882SAndroid Build Coastguard Worker }
175*993b0882SAndroid Build Coastguard Worker 
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromString(
177*993b0882SAndroid Build Coastguard Worker     const std::string& buffer, const UniLib* unilib,
178*993b0882SAndroid Build Coastguard Worker     const CalendarLib* calendarlib) {
179*993b0882SAndroid Build Coastguard Worker   auto classifier = std::unique_ptr<Annotator>(new Annotator());
180*993b0882SAndroid Build Coastguard Worker   classifier->owned_buffer_ = buffer;
181*993b0882SAndroid Build Coastguard Worker   const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182*993b0882SAndroid Build Coastguard Worker                                           classifier->owned_buffer_.size());
183*993b0882SAndroid Build Coastguard Worker   if (model == nullptr) {
184*993b0882SAndroid Build Coastguard Worker     return nullptr;
185*993b0882SAndroid Build Coastguard Worker   }
186*993b0882SAndroid Build Coastguard Worker   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187*993b0882SAndroid Build Coastguard Worker   calendarlib =
188*993b0882SAndroid Build Coastguard Worker       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189*993b0882SAndroid Build Coastguard Worker   classifier->ValidateAndInitialize(model, unilib, calendarlib);
190*993b0882SAndroid Build Coastguard Worker   if (!classifier->IsInitialized()) {
191*993b0882SAndroid Build Coastguard Worker     return nullptr;
192*993b0882SAndroid Build Coastguard Worker   }
193*993b0882SAndroid Build Coastguard Worker 
194*993b0882SAndroid Build Coastguard Worker   return classifier;
195*993b0882SAndroid Build Coastguard Worker }
196*993b0882SAndroid Build Coastguard Worker 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199*993b0882SAndroid Build Coastguard Worker     const CalendarLib* calendarlib) {
200*993b0882SAndroid Build Coastguard Worker   if (!(*mmap)->handle().ok()) {
201*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1) << "Mmap failed.";
202*993b0882SAndroid Build Coastguard Worker     return nullptr;
203*993b0882SAndroid Build Coastguard Worker   }
204*993b0882SAndroid Build Coastguard Worker 
205*993b0882SAndroid Build Coastguard Worker   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206*993b0882SAndroid Build Coastguard Worker                                           (*mmap)->handle().num_bytes());
207*993b0882SAndroid Build Coastguard Worker   if (!model) {
208*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Model verification failed.";
209*993b0882SAndroid Build Coastguard Worker     return nullptr;
210*993b0882SAndroid Build Coastguard Worker   }
211*993b0882SAndroid Build Coastguard Worker 
212*993b0882SAndroid Build Coastguard Worker   auto classifier = std::unique_ptr<Annotator>(new Annotator());
213*993b0882SAndroid Build Coastguard Worker   classifier->mmap_ = std::move(*mmap);
214*993b0882SAndroid Build Coastguard Worker   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215*993b0882SAndroid Build Coastguard Worker   calendarlib =
216*993b0882SAndroid Build Coastguard Worker       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217*993b0882SAndroid Build Coastguard Worker   classifier->ValidateAndInitialize(model, unilib, calendarlib);
218*993b0882SAndroid Build Coastguard Worker   if (!classifier->IsInitialized()) {
219*993b0882SAndroid Build Coastguard Worker     return nullptr;
220*993b0882SAndroid Build Coastguard Worker   }
221*993b0882SAndroid Build Coastguard Worker 
222*993b0882SAndroid Build Coastguard Worker   return classifier;
223*993b0882SAndroid Build Coastguard Worker }
224*993b0882SAndroid Build Coastguard Worker 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<CalendarLib> calendarlib) {
228*993b0882SAndroid Build Coastguard Worker   if (!(*mmap)->handle().ok()) {
229*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1) << "Mmap failed.";
230*993b0882SAndroid Build Coastguard Worker     return nullptr;
231*993b0882SAndroid Build Coastguard Worker   }
232*993b0882SAndroid Build Coastguard Worker 
233*993b0882SAndroid Build Coastguard Worker   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234*993b0882SAndroid Build Coastguard Worker                                           (*mmap)->handle().num_bytes());
235*993b0882SAndroid Build Coastguard Worker   if (model == nullptr) {
236*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Model verification failed.";
237*993b0882SAndroid Build Coastguard Worker     return nullptr;
238*993b0882SAndroid Build Coastguard Worker   }
239*993b0882SAndroid Build Coastguard Worker 
240*993b0882SAndroid Build Coastguard Worker   auto classifier = std::unique_ptr<Annotator>(new Annotator());
241*993b0882SAndroid Build Coastguard Worker   classifier->mmap_ = std::move(*mmap);
242*993b0882SAndroid Build Coastguard Worker   classifier->owned_unilib_ = std::move(unilib);
243*993b0882SAndroid Build Coastguard Worker   classifier->owned_calendarlib_ = std::move(calendarlib);
244*993b0882SAndroid Build Coastguard Worker   classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245*993b0882SAndroid Build Coastguard Worker                                     classifier->owned_calendarlib_.get());
246*993b0882SAndroid Build Coastguard Worker   if (!classifier->IsInitialized()) {
247*993b0882SAndroid Build Coastguard Worker     return nullptr;
248*993b0882SAndroid Build Coastguard Worker   }
249*993b0882SAndroid Build Coastguard Worker 
250*993b0882SAndroid Build Coastguard Worker   return classifier;
251*993b0882SAndroid Build Coastguard Worker }
252*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254*993b0882SAndroid Build Coastguard Worker     int fd, int offset, int size, const UniLib* unilib,
255*993b0882SAndroid Build Coastguard Worker     const CalendarLib* calendarlib) {
256*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(&mmap, unilib, calendarlib);
258*993b0882SAndroid Build Coastguard Worker }
259*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261*993b0882SAndroid Build Coastguard Worker     int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<CalendarLib> calendarlib) {
263*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265*993b0882SAndroid Build Coastguard Worker }
266*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268*993b0882SAndroid Build Coastguard Worker     int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(&mmap, unilib, calendarlib);
271*993b0882SAndroid Build Coastguard Worker }
272*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274*993b0882SAndroid Build Coastguard Worker     int fd, std::unique_ptr<UniLib> unilib,
275*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<CalendarLib> calendarlib) {
276*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278*993b0882SAndroid Build Coastguard Worker }
279*993b0882SAndroid Build Coastguard Worker 
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281*993b0882SAndroid Build Coastguard Worker                                                const UniLib* unilib,
282*993b0882SAndroid Build Coastguard Worker                                                const CalendarLib* calendarlib) {
283*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(&mmap, unilib, calendarlib);
285*993b0882SAndroid Build Coastguard Worker }
286*993b0882SAndroid Build Coastguard Worker 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287*993b0882SAndroid Build Coastguard Worker std::unique_ptr<Annotator> Annotator::FromPath(
288*993b0882SAndroid Build Coastguard Worker     const std::string& path, std::unique_ptr<UniLib> unilib,
289*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<CalendarLib> calendarlib) {
290*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292*993b0882SAndroid Build Coastguard Worker }
293*993b0882SAndroid Build Coastguard Worker 
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294*993b0882SAndroid Build Coastguard Worker void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295*993b0882SAndroid Build Coastguard Worker                                       const CalendarLib* calendarlib) {
296*993b0882SAndroid Build Coastguard Worker   model_ = model;
297*993b0882SAndroid Build Coastguard Worker   unilib_ = unilib;
298*993b0882SAndroid Build Coastguard Worker   calendarlib_ = calendarlib;
299*993b0882SAndroid Build Coastguard Worker 
300*993b0882SAndroid Build Coastguard Worker   initialized_ = false;
301*993b0882SAndroid Build Coastguard Worker 
302*993b0882SAndroid Build Coastguard Worker   if (model_ == nullptr) {
303*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No model specified.";
304*993b0882SAndroid Build Coastguard Worker     return;
305*993b0882SAndroid Build Coastguard Worker   }
306*993b0882SAndroid Build Coastguard Worker 
307*993b0882SAndroid Build Coastguard Worker   const bool model_enabled_for_annotation =
308*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options() != nullptr &&
309*993b0882SAndroid Build Coastguard Worker        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310*993b0882SAndroid Build Coastguard Worker   const bool model_enabled_for_classification =
311*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options() != nullptr &&
312*993b0882SAndroid Build Coastguard Worker        (model_->triggering_options()->enabled_modes() &
313*993b0882SAndroid Build Coastguard Worker         ModeFlag_CLASSIFICATION));
314*993b0882SAndroid Build Coastguard Worker   const bool model_enabled_for_selection =
315*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options() != nullptr &&
316*993b0882SAndroid Build Coastguard Worker        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317*993b0882SAndroid Build Coastguard Worker 
318*993b0882SAndroid Build Coastguard Worker   // Annotation requires the selection model.
319*993b0882SAndroid Build Coastguard Worker   if (model_enabled_for_annotation || model_enabled_for_selection) {
320*993b0882SAndroid Build Coastguard Worker     if (!model_->selection_options()) {
321*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No selection options.";
322*993b0882SAndroid Build Coastguard Worker       return;
323*993b0882SAndroid Build Coastguard Worker     }
324*993b0882SAndroid Build Coastguard Worker     if (!model_->selection_feature_options()) {
325*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No selection feature options.";
326*993b0882SAndroid Build Coastguard Worker       return;
327*993b0882SAndroid Build Coastguard Worker     }
328*993b0882SAndroid Build Coastguard Worker     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330*993b0882SAndroid Build Coastguard Worker       return;
331*993b0882SAndroid Build Coastguard Worker     }
332*993b0882SAndroid Build Coastguard Worker     if (!model_->selection_model()) {
333*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No selection model.";
334*993b0882SAndroid Build Coastguard Worker       return;
335*993b0882SAndroid Build Coastguard Worker     }
336*993b0882SAndroid Build Coastguard Worker     selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337*993b0882SAndroid Build Coastguard Worker     if (!selection_executor_) {
338*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize selection executor.";
339*993b0882SAndroid Build Coastguard Worker       return;
340*993b0882SAndroid Build Coastguard Worker     }
341*993b0882SAndroid Build Coastguard Worker   }
342*993b0882SAndroid Build Coastguard Worker 
343*993b0882SAndroid Build Coastguard Worker   // Even if the annotation mode is not enabled (for the neural network model),
344*993b0882SAndroid Build Coastguard Worker   // the selection feature processor is needed to tokenize the text for other
345*993b0882SAndroid Build Coastguard Worker   // models.
346*993b0882SAndroid Build Coastguard Worker   if (model_->selection_feature_options()) {
347*993b0882SAndroid Build Coastguard Worker     selection_feature_processor_.reset(
348*993b0882SAndroid Build Coastguard Worker         new FeatureProcessor(model_->selection_feature_options(), unilib_));
349*993b0882SAndroid Build Coastguard Worker   }
350*993b0882SAndroid Build Coastguard Worker 
351*993b0882SAndroid Build Coastguard Worker   // Annotation requires the classification model for conflict resolution and
352*993b0882SAndroid Build Coastguard Worker   // scoring.
353*993b0882SAndroid Build Coastguard Worker   // Selection requires the classification model for conflict resolution.
354*993b0882SAndroid Build Coastguard Worker   if (model_enabled_for_annotation || model_enabled_for_classification ||
355*993b0882SAndroid Build Coastguard Worker       model_enabled_for_selection) {
356*993b0882SAndroid Build Coastguard Worker     if (!model_->classification_options()) {
357*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No classification options.";
358*993b0882SAndroid Build Coastguard Worker       return;
359*993b0882SAndroid Build Coastguard Worker     }
360*993b0882SAndroid Build Coastguard Worker 
361*993b0882SAndroid Build Coastguard Worker     if (!model_->classification_feature_options()) {
362*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No classification feature options.";
363*993b0882SAndroid Build Coastguard Worker       return;
364*993b0882SAndroid Build Coastguard Worker     }
365*993b0882SAndroid Build Coastguard Worker 
366*993b0882SAndroid Build Coastguard Worker     if (!model_->classification_feature_options()
367*993b0882SAndroid Build Coastguard Worker              ->bounds_sensitive_features()) {
368*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
369*993b0882SAndroid Build Coastguard Worker       return;
370*993b0882SAndroid Build Coastguard Worker     }
371*993b0882SAndroid Build Coastguard Worker     if (!model_->classification_model()) {
372*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No clf model.";
373*993b0882SAndroid Build Coastguard Worker       return;
374*993b0882SAndroid Build Coastguard Worker     }
375*993b0882SAndroid Build Coastguard Worker 
376*993b0882SAndroid Build Coastguard Worker     classification_executor_ =
377*993b0882SAndroid Build Coastguard Worker         ModelExecutor::FromBuffer(model_->classification_model());
378*993b0882SAndroid Build Coastguard Worker     if (!classification_executor_) {
379*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize classification executor.";
380*993b0882SAndroid Build Coastguard Worker       return;
381*993b0882SAndroid Build Coastguard Worker     }
382*993b0882SAndroid Build Coastguard Worker 
383*993b0882SAndroid Build Coastguard Worker     classification_feature_processor_.reset(new FeatureProcessor(
384*993b0882SAndroid Build Coastguard Worker         model_->classification_feature_options(), unilib_));
385*993b0882SAndroid Build Coastguard Worker   }
386*993b0882SAndroid Build Coastguard Worker 
387*993b0882SAndroid Build Coastguard Worker   // The embeddings need to be specified if the model is to be used for
388*993b0882SAndroid Build Coastguard Worker   // classification or selection.
389*993b0882SAndroid Build Coastguard Worker   if (model_enabled_for_annotation || model_enabled_for_classification ||
390*993b0882SAndroid Build Coastguard Worker       model_enabled_for_selection) {
391*993b0882SAndroid Build Coastguard Worker     if (!model_->embedding_model()) {
392*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No embedding model.";
393*993b0882SAndroid Build Coastguard Worker       return;
394*993b0882SAndroid Build Coastguard Worker     }
395*993b0882SAndroid Build Coastguard Worker 
396*993b0882SAndroid Build Coastguard Worker     // Check that the embedding size of the selection and classification model
397*993b0882SAndroid Build Coastguard Worker     // matches, as they are using the same embeddings.
398*993b0882SAndroid Build Coastguard Worker     if (model_enabled_for_selection &&
399*993b0882SAndroid Build Coastguard Worker         (model_->selection_feature_options()->embedding_size() !=
400*993b0882SAndroid Build Coastguard Worker              model_->classification_feature_options()->embedding_size() ||
401*993b0882SAndroid Build Coastguard Worker          model_->selection_feature_options()->embedding_quantization_bits() !=
402*993b0882SAndroid Build Coastguard Worker              model_->classification_feature_options()
403*993b0882SAndroid Build Coastguard Worker                  ->embedding_quantization_bits())) {
404*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
405*993b0882SAndroid Build Coastguard Worker       return;
406*993b0882SAndroid Build Coastguard Worker     }
407*993b0882SAndroid Build Coastguard Worker 
408*993b0882SAndroid Build Coastguard Worker     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
409*993b0882SAndroid Build Coastguard Worker         model_->embedding_model(),
410*993b0882SAndroid Build Coastguard Worker         model_->classification_feature_options()->embedding_size(),
411*993b0882SAndroid Build Coastguard Worker         model_->classification_feature_options()->embedding_quantization_bits(),
412*993b0882SAndroid Build Coastguard Worker         model_->embedding_pruning_mask());
413*993b0882SAndroid Build Coastguard Worker     if (!embedding_executor_) {
414*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
415*993b0882SAndroid Build Coastguard Worker       return;
416*993b0882SAndroid Build Coastguard Worker     }
417*993b0882SAndroid Build Coastguard Worker   }
418*993b0882SAndroid Build Coastguard Worker 
419*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
420*993b0882SAndroid Build Coastguard Worker   if (model_->regex_model()) {
421*993b0882SAndroid Build Coastguard Worker     if (!InitializeRegexModel(decompressor.get())) {
422*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize regex model.";
423*993b0882SAndroid Build Coastguard Worker       return;
424*993b0882SAndroid Build Coastguard Worker     }
425*993b0882SAndroid Build Coastguard Worker   }
426*993b0882SAndroid Build Coastguard Worker 
427*993b0882SAndroid Build Coastguard Worker   if (model_->datetime_grammar_model()) {
428*993b0882SAndroid Build Coastguard Worker     if (model_->datetime_grammar_model()->rules()) {
429*993b0882SAndroid Build Coastguard Worker       analyzer_ = std::make_unique<grammar::Analyzer>(
430*993b0882SAndroid Build Coastguard Worker           unilib_, model_->datetime_grammar_model()->rules());
431*993b0882SAndroid Build Coastguard Worker       datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
432*993b0882SAndroid Build Coastguard Worker       datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
433*993b0882SAndroid Build Coastguard Worker           *analyzer_, *datetime_grounder_,
434*993b0882SAndroid Build Coastguard Worker           /*target_classification_score=*/1.0,
435*993b0882SAndroid Build Coastguard Worker           /*priority_score=*/1.0,
436*993b0882SAndroid Build Coastguard Worker           model_->datetime_grammar_model()->enabled_modes());
437*993b0882SAndroid Build Coastguard Worker     }
438*993b0882SAndroid Build Coastguard Worker   } else if (model_->datetime_model()) {
439*993b0882SAndroid Build Coastguard Worker     datetime_parser_ = RegexDatetimeParser::Instance(
440*993b0882SAndroid Build Coastguard Worker         model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
441*993b0882SAndroid Build Coastguard Worker     if (!datetime_parser_) {
442*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize datetime parser.";
443*993b0882SAndroid Build Coastguard Worker       return;
444*993b0882SAndroid Build Coastguard Worker     }
445*993b0882SAndroid Build Coastguard Worker   }
446*993b0882SAndroid Build Coastguard Worker 
447*993b0882SAndroid Build Coastguard Worker   if (model_->output_options()) {
448*993b0882SAndroid Build Coastguard Worker     if (model_->output_options()->filtered_collections_annotation()) {
449*993b0882SAndroid Build Coastguard Worker       for (const auto collection :
450*993b0882SAndroid Build Coastguard Worker            *model_->output_options()->filtered_collections_annotation()) {
451*993b0882SAndroid Build Coastguard Worker         filtered_collections_annotation_.insert(collection->str());
452*993b0882SAndroid Build Coastguard Worker       }
453*993b0882SAndroid Build Coastguard Worker     }
454*993b0882SAndroid Build Coastguard Worker     if (model_->output_options()->filtered_collections_classification()) {
455*993b0882SAndroid Build Coastguard Worker       for (const auto collection :
456*993b0882SAndroid Build Coastguard Worker            *model_->output_options()->filtered_collections_classification()) {
457*993b0882SAndroid Build Coastguard Worker         filtered_collections_classification_.insert(collection->str());
458*993b0882SAndroid Build Coastguard Worker       }
459*993b0882SAndroid Build Coastguard Worker     }
460*993b0882SAndroid Build Coastguard Worker     if (model_->output_options()->filtered_collections_selection()) {
461*993b0882SAndroid Build Coastguard Worker       for (const auto collection :
462*993b0882SAndroid Build Coastguard Worker            *model_->output_options()->filtered_collections_selection()) {
463*993b0882SAndroid Build Coastguard Worker         filtered_collections_selection_.insert(collection->str());
464*993b0882SAndroid Build Coastguard Worker       }
465*993b0882SAndroid Build Coastguard Worker     }
466*993b0882SAndroid Build Coastguard Worker   }
467*993b0882SAndroid Build Coastguard Worker 
468*993b0882SAndroid Build Coastguard Worker   if (model_->number_annotator_options() &&
469*993b0882SAndroid Build Coastguard Worker       model_->number_annotator_options()->enabled()) {
470*993b0882SAndroid Build Coastguard Worker     number_annotator_.reset(
471*993b0882SAndroid Build Coastguard Worker         new NumberAnnotator(model_->number_annotator_options(), unilib_));
472*993b0882SAndroid Build Coastguard Worker   }
473*993b0882SAndroid Build Coastguard Worker 
474*993b0882SAndroid Build Coastguard Worker   if (model_->money_parsing_options()) {
475*993b0882SAndroid Build Coastguard Worker     money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
476*993b0882SAndroid Build Coastguard Worker         model_->money_parsing_options()->separators());
477*993b0882SAndroid Build Coastguard Worker   }
478*993b0882SAndroid Build Coastguard Worker 
479*993b0882SAndroid Build Coastguard Worker   if (model_->duration_annotator_options() &&
480*993b0882SAndroid Build Coastguard Worker       model_->duration_annotator_options()->enabled()) {
481*993b0882SAndroid Build Coastguard Worker     duration_annotator_.reset(
482*993b0882SAndroid Build Coastguard Worker         new DurationAnnotator(model_->duration_annotator_options(),
483*993b0882SAndroid Build Coastguard Worker                               selection_feature_processor_.get(), unilib_));
484*993b0882SAndroid Build Coastguard Worker   }
485*993b0882SAndroid Build Coastguard Worker 
486*993b0882SAndroid Build Coastguard Worker   if (model_->grammar_model()) {
487*993b0882SAndroid Build Coastguard Worker     grammar_annotator_.reset(new GrammarAnnotator(
488*993b0882SAndroid Build Coastguard Worker         unilib_, model_->grammar_model(), entity_data_builder_.get()));
489*993b0882SAndroid Build Coastguard Worker   }
490*993b0882SAndroid Build Coastguard Worker 
491*993b0882SAndroid Build Coastguard Worker   // The following #ifdef is here to aid quality evaluation of a situation, when
492*993b0882SAndroid Build Coastguard Worker   // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
493*993b0882SAndroid Build Coastguard Worker   // it.
494*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_POD_NER)
495*993b0882SAndroid Build Coastguard Worker   if (model_->pod_ner_model()) {
496*993b0882SAndroid Build Coastguard Worker     pod_ner_annotator_ =
497*993b0882SAndroid Build Coastguard Worker         PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
498*993b0882SAndroid Build Coastguard Worker   }
499*993b0882SAndroid Build Coastguard Worker #endif
500*993b0882SAndroid Build Coastguard Worker 
501*993b0882SAndroid Build Coastguard Worker   if (model_->vocab_model()) {
502*993b0882SAndroid Build Coastguard Worker     vocab_annotator_ = VocabAnnotator::Create(
503*993b0882SAndroid Build Coastguard Worker         model_->vocab_model(), *selection_feature_processor_, *unilib_);
504*993b0882SAndroid Build Coastguard Worker   }
505*993b0882SAndroid Build Coastguard Worker 
506*993b0882SAndroid Build Coastguard Worker   if (model_->entity_data_schema()) {
507*993b0882SAndroid Build Coastguard Worker     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
508*993b0882SAndroid Build Coastguard Worker         model_->entity_data_schema()->Data(),
509*993b0882SAndroid Build Coastguard Worker         model_->entity_data_schema()->size());
510*993b0882SAndroid Build Coastguard Worker     if (entity_data_schema_ == nullptr) {
511*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not load entity data schema data.";
512*993b0882SAndroid Build Coastguard Worker       return;
513*993b0882SAndroid Build Coastguard Worker     }
514*993b0882SAndroid Build Coastguard Worker 
515*993b0882SAndroid Build Coastguard Worker     entity_data_builder_.reset(
516*993b0882SAndroid Build Coastguard Worker         new MutableFlatbufferBuilder(entity_data_schema_));
517*993b0882SAndroid Build Coastguard Worker   } else {
518*993b0882SAndroid Build Coastguard Worker     entity_data_schema_ = nullptr;
519*993b0882SAndroid Build Coastguard Worker     entity_data_builder_ = nullptr;
520*993b0882SAndroid Build Coastguard Worker   }
521*993b0882SAndroid Build Coastguard Worker 
522*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_locales() &&
523*993b0882SAndroid Build Coastguard Worker       !ParseLocales(model_->triggering_locales()->c_str(),
524*993b0882SAndroid Build Coastguard Worker                     &model_triggering_locales_)) {
525*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not parse model supported locales.";
526*993b0882SAndroid Build Coastguard Worker     return;
527*993b0882SAndroid Build Coastguard Worker   }
528*993b0882SAndroid Build Coastguard Worker 
529*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() != nullptr &&
530*993b0882SAndroid Build Coastguard Worker       model_->triggering_options()->locales() != nullptr &&
531*993b0882SAndroid Build Coastguard Worker       !ParseLocales(model_->triggering_options()->locales()->c_str(),
532*993b0882SAndroid Build Coastguard Worker                     &ml_model_triggering_locales_)) {
533*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
534*993b0882SAndroid Build Coastguard Worker     return;
535*993b0882SAndroid Build Coastguard Worker   }
536*993b0882SAndroid Build Coastguard Worker 
537*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() != nullptr &&
538*993b0882SAndroid Build Coastguard Worker       model_->triggering_options()->dictionary_locales() != nullptr &&
539*993b0882SAndroid Build Coastguard Worker       !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
540*993b0882SAndroid Build Coastguard Worker                     &dictionary_locales_)) {
541*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
542*993b0882SAndroid Build Coastguard Worker     return;
543*993b0882SAndroid Build Coastguard Worker   }
544*993b0882SAndroid Build Coastguard Worker 
545*993b0882SAndroid Build Coastguard Worker   if (model_->conflict_resolution_options() != nullptr) {
546*993b0882SAndroid Build Coastguard Worker     prioritize_longest_annotation_ =
547*993b0882SAndroid Build Coastguard Worker         model_->conflict_resolution_options()->prioritize_longest_annotation();
548*993b0882SAndroid Build Coastguard Worker     do_conflict_resolution_in_raw_mode_ =
549*993b0882SAndroid Build Coastguard Worker         model_->conflict_resolution_options()
550*993b0882SAndroid Build Coastguard Worker             ->do_conflict_resolution_in_raw_mode();
551*993b0882SAndroid Build Coastguard Worker   }
552*993b0882SAndroid Build Coastguard Worker 
553*993b0882SAndroid Build Coastguard Worker #ifdef TC3_EXPERIMENTAL
554*993b0882SAndroid Build Coastguard Worker   TC3_LOG(WARNING) << "Enabling experimental annotators.";
555*993b0882SAndroid Build Coastguard Worker   InitializeExperimentalAnnotators();
556*993b0882SAndroid Build Coastguard Worker #endif
557*993b0882SAndroid Build Coastguard Worker 
558*993b0882SAndroid Build Coastguard Worker   initialized_ = true;
559*993b0882SAndroid Build Coastguard Worker }
560*993b0882SAndroid Build Coastguard Worker 
InitializeRegexModel(ZlibDecompressor * decompressor)561*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
562*993b0882SAndroid Build Coastguard Worker   if (!model_->regex_model()->patterns()) {
563*993b0882SAndroid Build Coastguard Worker     return true;
564*993b0882SAndroid Build Coastguard Worker   }
565*993b0882SAndroid Build Coastguard Worker 
566*993b0882SAndroid Build Coastguard Worker   // Initialize pattern recognizers.
567*993b0882SAndroid Build Coastguard Worker   int regex_pattern_id = 0;
568*993b0882SAndroid Build Coastguard Worker   for (const auto regex_pattern : *model_->regex_model()->patterns()) {
569*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
570*993b0882SAndroid Build Coastguard Worker         UncompressMakeRegexPattern(
571*993b0882SAndroid Build Coastguard Worker             *unilib_, regex_pattern->pattern(),
572*993b0882SAndroid Build Coastguard Worker             regex_pattern->compressed_pattern(),
573*993b0882SAndroid Build Coastguard Worker             model_->regex_model()->lazy_regex_compilation(), decompressor);
574*993b0882SAndroid Build Coastguard Worker     if (!compiled_pattern) {
575*993b0882SAndroid Build Coastguard Worker       TC3_LOG(INFO) << "Failed to load regex pattern";
576*993b0882SAndroid Build Coastguard Worker       return false;
577*993b0882SAndroid Build Coastguard Worker     }
578*993b0882SAndroid Build Coastguard Worker 
579*993b0882SAndroid Build Coastguard Worker     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
580*993b0882SAndroid Build Coastguard Worker       annotation_regex_patterns_.push_back(regex_pattern_id);
581*993b0882SAndroid Build Coastguard Worker     }
582*993b0882SAndroid Build Coastguard Worker     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
583*993b0882SAndroid Build Coastguard Worker       classification_regex_patterns_.push_back(regex_pattern_id);
584*993b0882SAndroid Build Coastguard Worker     }
585*993b0882SAndroid Build Coastguard Worker     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
586*993b0882SAndroid Build Coastguard Worker       selection_regex_patterns_.push_back(regex_pattern_id);
587*993b0882SAndroid Build Coastguard Worker     }
588*993b0882SAndroid Build Coastguard Worker     regex_patterns_.push_back({
589*993b0882SAndroid Build Coastguard Worker         regex_pattern,
590*993b0882SAndroid Build Coastguard Worker         std::move(compiled_pattern),
591*993b0882SAndroid Build Coastguard Worker     });
592*993b0882SAndroid Build Coastguard Worker     ++regex_pattern_id;
593*993b0882SAndroid Build Coastguard Worker   }
594*993b0882SAndroid Build Coastguard Worker 
595*993b0882SAndroid Build Coastguard Worker   return true;
596*993b0882SAndroid Build Coastguard Worker }
597*993b0882SAndroid Build Coastguard Worker 
InitializeKnowledgeEngine(const std::string & serialized_config)598*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializeKnowledgeEngine(
599*993b0882SAndroid Build Coastguard Worker     const std::string& serialized_config) {
600*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
601*993b0882SAndroid Build Coastguard Worker   if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
602*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
603*993b0882SAndroid Build Coastguard Worker     return false;
604*993b0882SAndroid Build Coastguard Worker   }
605*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() != nullptr) {
606*993b0882SAndroid Build Coastguard Worker     knowledge_engine->SetPriorityScore(
607*993b0882SAndroid Build Coastguard Worker         model_->triggering_options()->knowledge_priority_score());
608*993b0882SAndroid Build Coastguard Worker     knowledge_engine->SetEnabledModes(
609*993b0882SAndroid Build Coastguard Worker         model_->triggering_options()->knowledge_enabled_modes());
610*993b0882SAndroid Build Coastguard Worker   }
611*993b0882SAndroid Build Coastguard Worker   knowledge_engine_ = std::move(knowledge_engine);
612*993b0882SAndroid Build Coastguard Worker   return true;
613*993b0882SAndroid Build Coastguard Worker }
614*993b0882SAndroid Build Coastguard Worker 
InitializeContactEngine(const std::string & serialized_config)615*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
616*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ContactEngine> contact_engine(
617*993b0882SAndroid Build Coastguard Worker       new ContactEngine(selection_feature_processor_.get(), unilib_,
618*993b0882SAndroid Build Coastguard Worker                         model_->contact_annotator_options()));
619*993b0882SAndroid Build Coastguard Worker   if (!contact_engine->Initialize(serialized_config)) {
620*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
621*993b0882SAndroid Build Coastguard Worker     return false;
622*993b0882SAndroid Build Coastguard Worker   }
623*993b0882SAndroid Build Coastguard Worker   contact_engine_ = std::move(contact_engine);
624*993b0882SAndroid Build Coastguard Worker   return true;
625*993b0882SAndroid Build Coastguard Worker }
626*993b0882SAndroid Build Coastguard Worker 
CleanUpContactEngine()627*993b0882SAndroid Build Coastguard Worker void Annotator::CleanUpContactEngine() {
628*993b0882SAndroid Build Coastguard Worker   if (contact_engine_ == nullptr) {
629*993b0882SAndroid Build Coastguard Worker     TC3_LOG(INFO)
630*993b0882SAndroid Build Coastguard Worker         << "Attempting to clean up contact engine that does not exist.";
631*993b0882SAndroid Build Coastguard Worker     return;
632*993b0882SAndroid Build Coastguard Worker   }
633*993b0882SAndroid Build Coastguard Worker   contact_engine_->CleanUp();
634*993b0882SAndroid Build Coastguard Worker }
635*993b0882SAndroid Build Coastguard Worker 
InitializeInstalledAppEngine(const std::string & serialized_config)636*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializeInstalledAppEngine(
637*993b0882SAndroid Build Coastguard Worker     const std::string& serialized_config) {
638*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<InstalledAppEngine> installed_app_engine(
639*993b0882SAndroid Build Coastguard Worker       new InstalledAppEngine(
640*993b0882SAndroid Build Coastguard Worker           selection_feature_processor_.get(), unilib_,
641*993b0882SAndroid Build Coastguard Worker           model_->triggering_options()->installed_app_enabled_modes()));
642*993b0882SAndroid Build Coastguard Worker   if (!installed_app_engine->Initialize(serialized_config)) {
643*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
644*993b0882SAndroid Build Coastguard Worker     return false;
645*993b0882SAndroid Build Coastguard Worker   }
646*993b0882SAndroid Build Coastguard Worker   installed_app_engine_ = std::move(installed_app_engine);
647*993b0882SAndroid Build Coastguard Worker   return true;
648*993b0882SAndroid Build Coastguard Worker }
649*993b0882SAndroid Build Coastguard Worker 
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)650*993b0882SAndroid Build Coastguard Worker bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
651*993b0882SAndroid Build Coastguard Worker   if (lang_id == nullptr) {
652*993b0882SAndroid Build Coastguard Worker     return false;
653*993b0882SAndroid Build Coastguard Worker   }
654*993b0882SAndroid Build Coastguard Worker 
655*993b0882SAndroid Build Coastguard Worker   lang_id_ = lang_id;
656*993b0882SAndroid Build Coastguard Worker   if (lang_id_ != nullptr && model_->translate_annotator_options() &&
657*993b0882SAndroid Build Coastguard Worker       model_->translate_annotator_options()->enabled()) {
658*993b0882SAndroid Build Coastguard Worker     translate_annotator_.reset(new TranslateAnnotator(
659*993b0882SAndroid Build Coastguard Worker         model_->translate_annotator_options(), lang_id_, unilib_));
660*993b0882SAndroid Build Coastguard Worker   } else {
661*993b0882SAndroid Build Coastguard Worker     translate_annotator_.reset(nullptr);
662*993b0882SAndroid Build Coastguard Worker   }
663*993b0882SAndroid Build Coastguard Worker   return true;
664*993b0882SAndroid Build Coastguard Worker }
665*993b0882SAndroid Build Coastguard Worker 
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)666*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
667*993b0882SAndroid Build Coastguard Worker                                                             int size) {
668*993b0882SAndroid Build Coastguard Worker   const PersonNameModel* person_name_model =
669*993b0882SAndroid Build Coastguard Worker       LoadAndVerifyPersonNameModel(buffer, size);
670*993b0882SAndroid Build Coastguard Worker 
671*993b0882SAndroid Build Coastguard Worker   if (person_name_model == nullptr) {
672*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Person name model verification failed.";
673*993b0882SAndroid Build Coastguard Worker     return false;
674*993b0882SAndroid Build Coastguard Worker   }
675*993b0882SAndroid Build Coastguard Worker 
676*993b0882SAndroid Build Coastguard Worker   if (!person_name_model->enabled()) {
677*993b0882SAndroid Build Coastguard Worker     return true;
678*993b0882SAndroid Build Coastguard Worker   }
679*993b0882SAndroid Build Coastguard Worker 
680*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<PersonNameEngine> person_name_engine(
681*993b0882SAndroid Build Coastguard Worker       new PersonNameEngine(selection_feature_processor_.get(), unilib_));
682*993b0882SAndroid Build Coastguard Worker   if (!person_name_engine->Initialize(person_name_model)) {
683*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
684*993b0882SAndroid Build Coastguard Worker     return false;
685*993b0882SAndroid Build Coastguard Worker   }
686*993b0882SAndroid Build Coastguard Worker   person_name_engine_ = std::move(person_name_engine);
687*993b0882SAndroid Build Coastguard Worker   return true;
688*993b0882SAndroid Build Coastguard Worker }
689*993b0882SAndroid Build Coastguard Worker 
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)690*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializePersonNameEngineFromScopedMmap(
691*993b0882SAndroid Build Coastguard Worker     const ScopedMmap& mmap) {
692*993b0882SAndroid Build Coastguard Worker   if (!mmap.handle().ok()) {
693*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Mmap for person name model failed.";
694*993b0882SAndroid Build Coastguard Worker     return false;
695*993b0882SAndroid Build Coastguard Worker   }
696*993b0882SAndroid Build Coastguard Worker 
697*993b0882SAndroid Build Coastguard Worker   return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
698*993b0882SAndroid Build Coastguard Worker                                                      mmap.handle().num_bytes());
699*993b0882SAndroid Build Coastguard Worker }
700*993b0882SAndroid Build Coastguard Worker 
InitializePersonNameEngineFromPath(const std::string & path)701*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
702*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
703*993b0882SAndroid Build Coastguard Worker   return InitializePersonNameEngineFromScopedMmap(*mmap);
704*993b0882SAndroid Build Coastguard Worker }
705*993b0882SAndroid Build Coastguard Worker 
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)706*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
707*993b0882SAndroid Build Coastguard Worker                                                              int size) {
708*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
709*993b0882SAndroid Build Coastguard Worker   return InitializePersonNameEngineFromScopedMmap(*mmap);
710*993b0882SAndroid Build Coastguard Worker }
711*993b0882SAndroid Build Coastguard Worker 
InitializeExperimentalAnnotators()712*993b0882SAndroid Build Coastguard Worker bool Annotator::InitializeExperimentalAnnotators() {
713*993b0882SAndroid Build Coastguard Worker   if (ExperimentalAnnotator::IsEnabled()) {
714*993b0882SAndroid Build Coastguard Worker     experimental_annotator_.reset(new ExperimentalAnnotator(
715*993b0882SAndroid Build Coastguard Worker         model_->experimental_model(), *selection_feature_processor_, *unilib_));
716*993b0882SAndroid Build Coastguard Worker     return true;
717*993b0882SAndroid Build Coastguard Worker   }
718*993b0882SAndroid Build Coastguard Worker   return false;
719*993b0882SAndroid Build Coastguard Worker }
720*993b0882SAndroid Build Coastguard Worker 
721*993b0882SAndroid Build Coastguard Worker namespace internal {
722*993b0882SAndroid Build Coastguard Worker // Helper function, which if the initial 'span' contains only white-spaces,
723*993b0882SAndroid Build Coastguard Worker // moves the selection to a single-codepoint selection on a left or right side
724*993b0882SAndroid Build Coastguard Worker // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)725*993b0882SAndroid Build Coastguard Worker CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
726*993b0882SAndroid Build Coastguard Worker                                             const UnicodeText& context_unicode,
727*993b0882SAndroid Build Coastguard Worker                                             const UniLib& unilib) {
728*993b0882SAndroid Build Coastguard Worker   TC3_CHECK(span.IsValid() && !span.IsEmpty());
729*993b0882SAndroid Build Coastguard Worker 
730*993b0882SAndroid Build Coastguard Worker   UnicodeText::const_iterator it;
731*993b0882SAndroid Build Coastguard Worker 
732*993b0882SAndroid Build Coastguard Worker   // Check that the current selection is all whitespaces.
733*993b0882SAndroid Build Coastguard Worker   it = context_unicode.begin();
734*993b0882SAndroid Build Coastguard Worker   std::advance(it, span.first);
735*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
736*993b0882SAndroid Build Coastguard Worker     if (!unilib.IsWhitespace(*it)) {
737*993b0882SAndroid Build Coastguard Worker       return span;
738*993b0882SAndroid Build Coastguard Worker     }
739*993b0882SAndroid Build Coastguard Worker   }
740*993b0882SAndroid Build Coastguard Worker 
741*993b0882SAndroid Build Coastguard Worker   // Try moving left.
742*993b0882SAndroid Build Coastguard Worker   CodepointSpan result = span;
743*993b0882SAndroid Build Coastguard Worker   it = context_unicode.begin();
744*993b0882SAndroid Build Coastguard Worker   std::advance(it, span.first);
745*993b0882SAndroid Build Coastguard Worker   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
746*993b0882SAndroid Build Coastguard Worker     --result.first;
747*993b0882SAndroid Build Coastguard Worker     --it;
748*993b0882SAndroid Build Coastguard Worker   }
749*993b0882SAndroid Build Coastguard Worker   result.second = result.first + 1;
750*993b0882SAndroid Build Coastguard Worker   if (!unilib.IsWhitespace(*it)) {
751*993b0882SAndroid Build Coastguard Worker     return result;
752*993b0882SAndroid Build Coastguard Worker   }
753*993b0882SAndroid Build Coastguard Worker 
754*993b0882SAndroid Build Coastguard Worker   // If moving left didn't find a non-whitespace character, just return the
755*993b0882SAndroid Build Coastguard Worker   // original span.
756*993b0882SAndroid Build Coastguard Worker   return span;
757*993b0882SAndroid Build Coastguard Worker }
758*993b0882SAndroid Build Coastguard Worker }  // namespace internal
759*993b0882SAndroid Build Coastguard Worker 
FilteredForAnnotation(const AnnotatedSpan & span) const760*993b0882SAndroid Build Coastguard Worker bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
761*993b0882SAndroid Build Coastguard Worker   return !span.classification.empty() &&
762*993b0882SAndroid Build Coastguard Worker          filtered_collections_annotation_.find(
763*993b0882SAndroid Build Coastguard Worker              span.classification[0].collection) !=
764*993b0882SAndroid Build Coastguard Worker              filtered_collections_annotation_.end();
765*993b0882SAndroid Build Coastguard Worker }
766*993b0882SAndroid Build Coastguard Worker 
FilteredForClassification(const ClassificationResult & classification) const767*993b0882SAndroid Build Coastguard Worker bool Annotator::FilteredForClassification(
768*993b0882SAndroid Build Coastguard Worker     const ClassificationResult& classification) const {
769*993b0882SAndroid Build Coastguard Worker   return filtered_collections_classification_.find(classification.collection) !=
770*993b0882SAndroid Build Coastguard Worker          filtered_collections_classification_.end();
771*993b0882SAndroid Build Coastguard Worker }
772*993b0882SAndroid Build Coastguard Worker 
FilteredForSelection(const AnnotatedSpan & span) const773*993b0882SAndroid Build Coastguard Worker bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
774*993b0882SAndroid Build Coastguard Worker   return !span.classification.empty() &&
775*993b0882SAndroid Build Coastguard Worker          filtered_collections_selection_.find(
776*993b0882SAndroid Build Coastguard Worker              span.classification[0].collection) !=
777*993b0882SAndroid Build Coastguard Worker              filtered_collections_selection_.end();
778*993b0882SAndroid Build Coastguard Worker }
779*993b0882SAndroid Build Coastguard Worker 
780*993b0882SAndroid Build Coastguard Worker namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)781*993b0882SAndroid Build Coastguard Worker inline bool ClassifiedAsOther(
782*993b0882SAndroid Build Coastguard Worker     const std::vector<ClassificationResult>& classification) {
783*993b0882SAndroid Build Coastguard Worker   return !classification.empty() &&
784*993b0882SAndroid Build Coastguard Worker          classification[0].collection == Collections::Other();
785*993b0882SAndroid Build Coastguard Worker }
786*993b0882SAndroid Build Coastguard Worker 
787*993b0882SAndroid Build Coastguard Worker }  // namespace
788*993b0882SAndroid Build Coastguard Worker 
GetPriorityScore(const std::vector<ClassificationResult> & classification) const789*993b0882SAndroid Build Coastguard Worker float Annotator::GetPriorityScore(
790*993b0882SAndroid Build Coastguard Worker     const std::vector<ClassificationResult>& classification) const {
791*993b0882SAndroid Build Coastguard Worker   if (!classification.empty() && !ClassifiedAsOther(classification)) {
792*993b0882SAndroid Build Coastguard Worker     return classification[0].priority_score;
793*993b0882SAndroid Build Coastguard Worker   } else {
794*993b0882SAndroid Build Coastguard Worker     if (model_->triggering_options() != nullptr) {
795*993b0882SAndroid Build Coastguard Worker       return model_->triggering_options()->other_collection_priority_score();
796*993b0882SAndroid Build Coastguard Worker     } else {
797*993b0882SAndroid Build Coastguard Worker       return -1000.0;
798*993b0882SAndroid Build Coastguard Worker     }
799*993b0882SAndroid Build Coastguard Worker   }
800*993b0882SAndroid Build Coastguard Worker }
801*993b0882SAndroid Build Coastguard Worker 
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const802*993b0882SAndroid Build Coastguard Worker bool Annotator::VerifyRegexMatchCandidate(
803*993b0882SAndroid Build Coastguard Worker     const std::string& context, const VerificationOptions* verification_options,
804*993b0882SAndroid Build Coastguard Worker     const std::string& match, const UniLib::RegexMatcher* matcher) const {
805*993b0882SAndroid Build Coastguard Worker   if (verification_options == nullptr) {
806*993b0882SAndroid Build Coastguard Worker     return true;
807*993b0882SAndroid Build Coastguard Worker   }
808*993b0882SAndroid Build Coastguard Worker   if (verification_options->verify_luhn_checksum() &&
809*993b0882SAndroid Build Coastguard Worker       !VerifyLuhnChecksum(match)) {
810*993b0882SAndroid Build Coastguard Worker     return false;
811*993b0882SAndroid Build Coastguard Worker   }
812*993b0882SAndroid Build Coastguard Worker   const int lua_verifier = verification_options->lua_verifier();
813*993b0882SAndroid Build Coastguard Worker   if (lua_verifier >= 0) {
814*993b0882SAndroid Build Coastguard Worker     if (model_->regex_model()->lua_verifier() == nullptr ||
815*993b0882SAndroid Build Coastguard Worker         lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
816*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
817*993b0882SAndroid Build Coastguard Worker       return false;
818*993b0882SAndroid Build Coastguard Worker     }
819*993b0882SAndroid Build Coastguard Worker     return VerifyMatch(
820*993b0882SAndroid Build Coastguard Worker         context, matcher,
821*993b0882SAndroid Build Coastguard Worker         model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
822*993b0882SAndroid Build Coastguard Worker   }
823*993b0882SAndroid Build Coastguard Worker   return true;
824*993b0882SAndroid Build Coastguard Worker }
825*993b0882SAndroid Build Coastguard Worker 
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const826*993b0882SAndroid Build Coastguard Worker CodepointSpan Annotator::SuggestSelection(
827*993b0882SAndroid Build Coastguard Worker     const std::string& context, CodepointSpan click_indices,
828*993b0882SAndroid Build Coastguard Worker     const SelectionOptions& options) const {
829*993b0882SAndroid Build Coastguard Worker   if (context.size() > std::numeric_limits<int>::max()) {
830*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
831*993b0882SAndroid Build Coastguard Worker     return {};
832*993b0882SAndroid Build Coastguard Worker   }
833*993b0882SAndroid Build Coastguard Worker 
834*993b0882SAndroid Build Coastguard Worker   CodepointSpan original_click_indices = click_indices;
835*993b0882SAndroid Build Coastguard Worker   if (!initialized_) {
836*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Not initialized";
837*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
838*993b0882SAndroid Build Coastguard Worker   }
839*993b0882SAndroid Build Coastguard Worker   if (options.annotation_usecase !=
840*993b0882SAndroid Build Coastguard Worker       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
841*993b0882SAndroid Build Coastguard Worker     TC3_LOG(WARNING)
842*993b0882SAndroid Build Coastguard Worker         << "Invoking SuggestSelection, which is not supported in RAW mode.";
843*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
844*993b0882SAndroid Build Coastguard Worker   }
845*993b0882SAndroid Build Coastguard Worker   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
846*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
847*993b0882SAndroid Build Coastguard Worker   }
848*993b0882SAndroid Build Coastguard Worker 
849*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> detected_text_language_tags;
850*993b0882SAndroid Build Coastguard Worker   if (!ParseLocales(options.detected_text_language_tags,
851*993b0882SAndroid Build Coastguard Worker                     &detected_text_language_tags)) {
852*993b0882SAndroid Build Coastguard Worker     TC3_LOG(WARNING)
853*993b0882SAndroid Build Coastguard Worker         << "Failed to parse the detected_text_language_tags in options: "
854*993b0882SAndroid Build Coastguard Worker         << options.detected_text_language_tags;
855*993b0882SAndroid Build Coastguard Worker   }
856*993b0882SAndroid Build Coastguard Worker   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
857*993b0882SAndroid Build Coastguard Worker                                     model_triggering_locales_,
858*993b0882SAndroid Build Coastguard Worker                                     /*default_value=*/true)) {
859*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
860*993b0882SAndroid Build Coastguard Worker   }
861*993b0882SAndroid Build Coastguard Worker 
862*993b0882SAndroid Build Coastguard Worker   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
863*993b0882SAndroid Build Coastguard Worker                                                         /*do_copy=*/false);
864*993b0882SAndroid Build Coastguard Worker 
865*993b0882SAndroid Build Coastguard Worker   if (!unilib_->IsValidUtf8(context_unicode)) {
866*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
867*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
868*993b0882SAndroid Build Coastguard Worker   }
869*993b0882SAndroid Build Coastguard Worker 
870*993b0882SAndroid Build Coastguard Worker   if (!IsValidSpanInput(context_unicode, click_indices)) {
871*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1)
872*993b0882SAndroid Build Coastguard Worker         << "Trying to run SuggestSelection with invalid input, indices: "
873*993b0882SAndroid Build Coastguard Worker         << click_indices.first << " " << click_indices.second;
874*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
875*993b0882SAndroid Build Coastguard Worker   }
876*993b0882SAndroid Build Coastguard Worker 
877*993b0882SAndroid Build Coastguard Worker   if (model_->snap_whitespace_selections()) {
878*993b0882SAndroid Build Coastguard Worker     // We want to expand a purely white-space selection to a multi-selection it
879*993b0882SAndroid Build Coastguard Worker     // would've been part of. But with this feature disabled we would do a no-
880*993b0882SAndroid Build Coastguard Worker     // op, because no token is found. Therefore, we need to modify the
881*993b0882SAndroid Build Coastguard Worker     // 'click_indices' a bit to include a part of the token, so that the click-
882*993b0882SAndroid Build Coastguard Worker     // finding logic finds the clicked token correctly. This modification is
883*993b0882SAndroid Build Coastguard Worker     // done by the following function. Note, that it's enough to check the left
884*993b0882SAndroid Build Coastguard Worker     // side of the current selection, because if the white-space is a part of a
885*993b0882SAndroid Build Coastguard Worker     // multi-selection, necessarily both tokens - on the left and the right
886*993b0882SAndroid Build Coastguard Worker     // sides need to be selected. Thus snapping only to the left is sufficient
887*993b0882SAndroid Build Coastguard Worker     // (there's a check at the bottom that makes sure that if we snap to the
888*993b0882SAndroid Build Coastguard Worker     // left token but the result does not contain the initial white-space,
889*993b0882SAndroid Build Coastguard Worker     // returns the original indices).
890*993b0882SAndroid Build Coastguard Worker     click_indices = internal::SnapLeftIfWhitespaceSelection(
891*993b0882SAndroid Build Coastguard Worker         click_indices, context_unicode, *unilib_);
892*993b0882SAndroid Build Coastguard Worker   }
893*993b0882SAndroid Build Coastguard Worker 
894*993b0882SAndroid Build Coastguard Worker   Annotations candidates;
895*993b0882SAndroid Build Coastguard Worker   // As we process a single string of context, the candidates will only
896*993b0882SAndroid Build Coastguard Worker   // contain one vector of AnnotatedSpan.
897*993b0882SAndroid Build Coastguard Worker   candidates.annotated_spans.resize(1);
898*993b0882SAndroid Build Coastguard Worker   InterpreterManager interpreter_manager(selection_executor_.get(),
899*993b0882SAndroid Build Coastguard Worker                                          classification_executor_.get());
900*993b0882SAndroid Build Coastguard Worker   std::vector<Token> tokens;
901*993b0882SAndroid Build Coastguard Worker   if (!ModelSuggestSelection(context_unicode, click_indices,
902*993b0882SAndroid Build Coastguard Worker                              detected_text_language_tags, &interpreter_manager,
903*993b0882SAndroid Build Coastguard Worker                              &tokens, &candidates.annotated_spans[0])) {
904*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Model suggest selection failed.";
905*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
906*993b0882SAndroid Build Coastguard Worker   }
907*993b0882SAndroid Build Coastguard Worker   const std::unordered_set<std::string> set;
908*993b0882SAndroid Build Coastguard Worker   const EnabledEntityTypes is_entity_type_enabled(set);
909*993b0882SAndroid Build Coastguard Worker   if (!RegexChunk(context_unicode, selection_regex_patterns_,
910*993b0882SAndroid Build Coastguard Worker                   /*is_serialized_entity_data_enabled=*/false,
911*993b0882SAndroid Build Coastguard Worker                   is_entity_type_enabled, options.annotation_usecase,
912*993b0882SAndroid Build Coastguard Worker                   &candidates.annotated_spans[0])) {
913*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Regex suggest selection failed.";
914*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
915*993b0882SAndroid Build Coastguard Worker   }
916*993b0882SAndroid Build Coastguard Worker   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
917*993b0882SAndroid Build Coastguard Worker                      /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
918*993b0882SAndroid Build Coastguard Worker                      options.locales, ModeFlag_SELECTION,
919*993b0882SAndroid Build Coastguard Worker                      options.annotation_usecase,
920*993b0882SAndroid Build Coastguard Worker                      /*is_serialized_entity_data_enabled=*/false,
921*993b0882SAndroid Build Coastguard Worker                      &candidates.annotated_spans[0])) {
922*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
923*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
924*993b0882SAndroid Build Coastguard Worker   }
925*993b0882SAndroid Build Coastguard Worker   if (knowledge_engine_ != nullptr &&
926*993b0882SAndroid Build Coastguard Worker       !knowledge_engine_
927*993b0882SAndroid Build Coastguard Worker            ->Chunk(context, options.annotation_usecase,
928*993b0882SAndroid Build Coastguard Worker                    options.location_context, Permissions(),
929*993b0882SAndroid Build Coastguard Worker                    AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION,
930*993b0882SAndroid Build Coastguard Worker                    &candidates)
931*993b0882SAndroid Build Coastguard Worker            .ok()) {
932*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
933*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
934*993b0882SAndroid Build Coastguard Worker   }
935*993b0882SAndroid Build Coastguard Worker   if (contact_engine_ != nullptr &&
936*993b0882SAndroid Build Coastguard Worker       !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
937*993b0882SAndroid Build Coastguard Worker                               &candidates.annotated_spans[0])) {
938*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Contact suggest selection failed.";
939*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
940*993b0882SAndroid Build Coastguard Worker   }
941*993b0882SAndroid Build Coastguard Worker   if (installed_app_engine_ != nullptr &&
942*993b0882SAndroid Build Coastguard Worker       !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
943*993b0882SAndroid Build Coastguard Worker                                     &candidates.annotated_spans[0])) {
944*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Installed app suggest selection failed.";
945*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
946*993b0882SAndroid Build Coastguard Worker   }
947*993b0882SAndroid Build Coastguard Worker   if (number_annotator_ != nullptr &&
948*993b0882SAndroid Build Coastguard Worker       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
949*993b0882SAndroid Build Coastguard Worker                                   ModeFlag_SELECTION,
950*993b0882SAndroid Build Coastguard Worker                                   &candidates.annotated_spans[0])) {
951*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
952*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
953*993b0882SAndroid Build Coastguard Worker   }
954*993b0882SAndroid Build Coastguard Worker   if (duration_annotator_ != nullptr &&
955*993b0882SAndroid Build Coastguard Worker       !duration_annotator_->FindAll(
956*993b0882SAndroid Build Coastguard Worker           context_unicode, tokens, options.annotation_usecase,
957*993b0882SAndroid Build Coastguard Worker           ModeFlag_SELECTION, &candidates.annotated_spans[0])) {
958*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
959*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
960*993b0882SAndroid Build Coastguard Worker   }
961*993b0882SAndroid Build Coastguard Worker   if (person_name_engine_ != nullptr &&
962*993b0882SAndroid Build Coastguard Worker       !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
963*993b0882SAndroid Build Coastguard Worker                                   &candidates.annotated_spans[0])) {
964*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Person name suggest selection failed.";
965*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
966*993b0882SAndroid Build Coastguard Worker   }
967*993b0882SAndroid Build Coastguard Worker 
968*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan grammar_suggested_span;
969*993b0882SAndroid Build Coastguard Worker   if (grammar_annotator_ != nullptr &&
970*993b0882SAndroid Build Coastguard Worker       grammar_annotator_->SuggestSelection(detected_text_language_tags,
971*993b0882SAndroid Build Coastguard Worker                                            context_unicode, click_indices,
972*993b0882SAndroid Build Coastguard Worker                                            &grammar_suggested_span)) {
973*993b0882SAndroid Build Coastguard Worker     candidates.annotated_spans[0].push_back(grammar_suggested_span);
974*993b0882SAndroid Build Coastguard Worker   }
975*993b0882SAndroid Build Coastguard Worker 
976*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan pod_ner_suggested_span;
977*993b0882SAndroid Build Coastguard Worker   if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
978*993b0882SAndroid Build Coastguard Worker       pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
979*993b0882SAndroid Build Coastguard Worker                                            &pod_ner_suggested_span)) {
980*993b0882SAndroid Build Coastguard Worker     candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
981*993b0882SAndroid Build Coastguard Worker   }
982*993b0882SAndroid Build Coastguard Worker 
983*993b0882SAndroid Build Coastguard Worker   if (experimental_annotator_ != nullptr &&
984*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options()->experimental_enabled_modes() &
985*993b0882SAndroid Build Coastguard Worker        ModeFlag_SELECTION)) {
986*993b0882SAndroid Build Coastguard Worker     candidates.annotated_spans[0].push_back(
987*993b0882SAndroid Build Coastguard Worker         experimental_annotator_->SuggestSelection(context_unicode,
988*993b0882SAndroid Build Coastguard Worker                                                   click_indices));
989*993b0882SAndroid Build Coastguard Worker   }
990*993b0882SAndroid Build Coastguard Worker 
991*993b0882SAndroid Build Coastguard Worker   // Sort candidates according to their position in the input, so that the next
992*993b0882SAndroid Build Coastguard Worker   // code can assume that any connected component of overlapping spans forms a
993*993b0882SAndroid Build Coastguard Worker   // contiguous block.
994*993b0882SAndroid Build Coastguard Worker   std::stable_sort(candidates.annotated_spans[0].begin(),
995*993b0882SAndroid Build Coastguard Worker                    candidates.annotated_spans[0].end(),
996*993b0882SAndroid Build Coastguard Worker                    [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
997*993b0882SAndroid Build Coastguard Worker                      return a.span.first < b.span.first;
998*993b0882SAndroid Build Coastguard Worker                    });
999*993b0882SAndroid Build Coastguard Worker 
1000*993b0882SAndroid Build Coastguard Worker   std::vector<int> candidate_indices;
1001*993b0882SAndroid Build Coastguard Worker   if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
1002*993b0882SAndroid Build Coastguard Worker                         detected_text_language_tags, options,
1003*993b0882SAndroid Build Coastguard Worker                         &interpreter_manager, &candidate_indices)) {
1004*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1005*993b0882SAndroid Build Coastguard Worker     return original_click_indices;
1006*993b0882SAndroid Build Coastguard Worker   }
1007*993b0882SAndroid Build Coastguard Worker 
1008*993b0882SAndroid Build Coastguard Worker   std::stable_sort(
1009*993b0882SAndroid Build Coastguard Worker       candidate_indices.begin(), candidate_indices.end(),
1010*993b0882SAndroid Build Coastguard Worker       [this, &candidates](int a, int b) {
1011*993b0882SAndroid Build Coastguard Worker         return GetPriorityScore(
1012*993b0882SAndroid Build Coastguard Worker                    candidates.annotated_spans[0][a].classification) >
1013*993b0882SAndroid Build Coastguard Worker                GetPriorityScore(
1014*993b0882SAndroid Build Coastguard Worker                    candidates.annotated_spans[0][b].classification);
1015*993b0882SAndroid Build Coastguard Worker       });
1016*993b0882SAndroid Build Coastguard Worker 
1017*993b0882SAndroid Build Coastguard Worker   for (const int i : candidate_indices) {
1018*993b0882SAndroid Build Coastguard Worker     if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
1019*993b0882SAndroid Build Coastguard Worker         SpansOverlap(candidates.annotated_spans[0][i].span,
1020*993b0882SAndroid Build Coastguard Worker                      original_click_indices)) {
1021*993b0882SAndroid Build Coastguard Worker       // Run model classification if not present but requested and there's a
1022*993b0882SAndroid Build Coastguard Worker       // classification collection filter specified.
1023*993b0882SAndroid Build Coastguard Worker       if (candidates.annotated_spans[0][i].classification.empty() &&
1024*993b0882SAndroid Build Coastguard Worker           model_->selection_options()->always_classify_suggested_selection() &&
1025*993b0882SAndroid Build Coastguard Worker           !filtered_collections_selection_.empty()) {
1026*993b0882SAndroid Build Coastguard Worker         if (!ModelClassifyText(context, /*cached_tokens=*/{},
1027*993b0882SAndroid Build Coastguard Worker                                detected_text_language_tags,
1028*993b0882SAndroid Build Coastguard Worker                                candidates.annotated_spans[0][i].span, options,
1029*993b0882SAndroid Build Coastguard Worker                                &interpreter_manager,
1030*993b0882SAndroid Build Coastguard Worker                                /*embedding_cache=*/nullptr,
1031*993b0882SAndroid Build Coastguard Worker                                &candidates.annotated_spans[0][i].classification,
1032*993b0882SAndroid Build Coastguard Worker                                /*tokens=*/nullptr)) {
1033*993b0882SAndroid Build Coastguard Worker           return original_click_indices;
1034*993b0882SAndroid Build Coastguard Worker         }
1035*993b0882SAndroid Build Coastguard Worker       }
1036*993b0882SAndroid Build Coastguard Worker 
1037*993b0882SAndroid Build Coastguard Worker       // Ignore if span classification is filtered.
1038*993b0882SAndroid Build Coastguard Worker       if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1039*993b0882SAndroid Build Coastguard Worker         return original_click_indices;
1040*993b0882SAndroid Build Coastguard Worker       }
1041*993b0882SAndroid Build Coastguard Worker 
1042*993b0882SAndroid Build Coastguard Worker       // We return a suggested span contains the original span.
1043*993b0882SAndroid Build Coastguard Worker       // This compensates for "select all" selection that may come from
1044*993b0882SAndroid Build Coastguard Worker       // other apps. See http://b/179890518.
1045*993b0882SAndroid Build Coastguard Worker       if (SpanContains(candidates.annotated_spans[0][i].span,
1046*993b0882SAndroid Build Coastguard Worker                        original_click_indices)) {
1047*993b0882SAndroid Build Coastguard Worker         return candidates.annotated_spans[0][i].span;
1048*993b0882SAndroid Build Coastguard Worker       }
1049*993b0882SAndroid Build Coastguard Worker     }
1050*993b0882SAndroid Build Coastguard Worker   }
1051*993b0882SAndroid Build Coastguard Worker 
1052*993b0882SAndroid Build Coastguard Worker   return original_click_indices;
1053*993b0882SAndroid Build Coastguard Worker }
1054*993b0882SAndroid Build Coastguard Worker 
1055*993b0882SAndroid Build Coastguard Worker namespace {
1056*993b0882SAndroid Build Coastguard Worker // Helper function that returns the index of the first candidate that
1057*993b0882SAndroid Build Coastguard Worker // transitively does not overlap with the candidate on 'start_index'. If the end
1058*993b0882SAndroid Build Coastguard Worker // of 'candidates' is reached, it returns the index that points right behind the
1059*993b0882SAndroid Build Coastguard Worker // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1060*993b0882SAndroid Build Coastguard Worker int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1061*993b0882SAndroid Build Coastguard Worker                                  int start_index) {
1062*993b0882SAndroid Build Coastguard Worker   int first_non_overlapping = start_index + 1;
1063*993b0882SAndroid Build Coastguard Worker   CodepointSpan conflicting_span = candidates[start_index].span;
1064*993b0882SAndroid Build Coastguard Worker   while (
1065*993b0882SAndroid Build Coastguard Worker       first_non_overlapping < candidates.size() &&
1066*993b0882SAndroid Build Coastguard Worker       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1067*993b0882SAndroid Build Coastguard Worker     // Grow the span to include the current one.
1068*993b0882SAndroid Build Coastguard Worker     conflicting_span.second = std::max(
1069*993b0882SAndroid Build Coastguard Worker         conflicting_span.second, candidates[first_non_overlapping].span.second);
1070*993b0882SAndroid Build Coastguard Worker 
1071*993b0882SAndroid Build Coastguard Worker     ++first_non_overlapping;
1072*993b0882SAndroid Build Coastguard Worker   }
1073*993b0882SAndroid Build Coastguard Worker   return first_non_overlapping;
1074*993b0882SAndroid Build Coastguard Worker }
1075*993b0882SAndroid Build Coastguard Worker }  // namespace
1076*993b0882SAndroid Build Coastguard Worker 
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1077*993b0882SAndroid Build Coastguard Worker bool Annotator::ResolveConflicts(
1078*993b0882SAndroid Build Coastguard Worker     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1079*993b0882SAndroid Build Coastguard Worker     const std::vector<Token>& cached_tokens,
1080*993b0882SAndroid Build Coastguard Worker     const std::vector<Locale>& detected_text_language_tags,
1081*993b0882SAndroid Build Coastguard Worker     const BaseOptions& options, InterpreterManager* interpreter_manager,
1082*993b0882SAndroid Build Coastguard Worker     std::vector<int>* result) const {
1083*993b0882SAndroid Build Coastguard Worker   result->clear();
1084*993b0882SAndroid Build Coastguard Worker   result->reserve(candidates.size());
1085*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < candidates.size();) {
1086*993b0882SAndroid Build Coastguard Worker     int first_non_overlapping =
1087*993b0882SAndroid Build Coastguard Worker         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1088*993b0882SAndroid Build Coastguard Worker 
1089*993b0882SAndroid Build Coastguard Worker     const bool conflict_found = first_non_overlapping != (i + 1);
1090*993b0882SAndroid Build Coastguard Worker     if (conflict_found) {
1091*993b0882SAndroid Build Coastguard Worker       std::vector<int> candidate_indices;
1092*993b0882SAndroid Build Coastguard Worker       if (!ResolveConflict(context, cached_tokens, candidates,
1093*993b0882SAndroid Build Coastguard Worker                            detected_text_language_tags, i,
1094*993b0882SAndroid Build Coastguard Worker                            first_non_overlapping, options, interpreter_manager,
1095*993b0882SAndroid Build Coastguard Worker                            &candidate_indices)) {
1096*993b0882SAndroid Build Coastguard Worker         return false;
1097*993b0882SAndroid Build Coastguard Worker       }
1098*993b0882SAndroid Build Coastguard Worker       result->insert(result->end(), candidate_indices.begin(),
1099*993b0882SAndroid Build Coastguard Worker                      candidate_indices.end());
1100*993b0882SAndroid Build Coastguard Worker     } else {
1101*993b0882SAndroid Build Coastguard Worker       result->push_back(i);
1102*993b0882SAndroid Build Coastguard Worker     }
1103*993b0882SAndroid Build Coastguard Worker 
1104*993b0882SAndroid Build Coastguard Worker     // Skip over the whole conflicting group/go to next candidate.
1105*993b0882SAndroid Build Coastguard Worker     i = first_non_overlapping;
1106*993b0882SAndroid Build Coastguard Worker   }
1107*993b0882SAndroid Build Coastguard Worker   return true;
1108*993b0882SAndroid Build Coastguard Worker }
1109*993b0882SAndroid Build Coastguard Worker 
1110*993b0882SAndroid Build Coastguard Worker namespace {
1111*993b0882SAndroid Build Coastguard Worker // Returns true, if the given two sources do conflict in given annotation
1112*993b0882SAndroid Build Coastguard Worker // usecase.
1113*993b0882SAndroid Build Coastguard Worker //  - In SMART usecase, all sources do conflict, because there's only 1 possible
1114*993b0882SAndroid Build Coastguard Worker //  annotation for a given span.
1115*993b0882SAndroid Build Coastguard Worker //  - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1116*993b0882SAndroid Build Coastguard Worker //  and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1117*993b0882SAndroid Build Coastguard Worker bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1118*993b0882SAndroid Build Coastguard Worker                        const AnnotatedSpan::Source source1,
1119*993b0882SAndroid Build Coastguard Worker                        const AnnotatedSpan::Source source2) {
1120*993b0882SAndroid Build Coastguard Worker   uint32 source_mask =
1121*993b0882SAndroid Build Coastguard Worker       (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1122*993b0882SAndroid Build Coastguard Worker 
1123*993b0882SAndroid Build Coastguard Worker   switch (annotation_usecase) {
1124*993b0882SAndroid Build Coastguard Worker     case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1125*993b0882SAndroid Build Coastguard Worker       // In the SMART mode, all annotations conflict.
1126*993b0882SAndroid Build Coastguard Worker       return true;
1127*993b0882SAndroid Build Coastguard Worker 
1128*993b0882SAndroid Build Coastguard Worker     case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1129*993b0882SAndroid Build Coastguard Worker       // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1130*993b0882SAndroid Build Coastguard Worker       // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1131*993b0882SAndroid Build Coastguard Worker       // hours" (duration).
1132*993b0882SAndroid Build Coastguard Worker       if ((source_mask &
1133*993b0882SAndroid Build Coastguard Worker            (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1134*993b0882SAndroid Build Coastguard Worker           (source_mask &
1135*993b0882SAndroid Build Coastguard Worker            (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1136*993b0882SAndroid Build Coastguard Worker         return false;
1137*993b0882SAndroid Build Coastguard Worker       }
1138*993b0882SAndroid Build Coastguard Worker 
1139*993b0882SAndroid Build Coastguard Worker       // A KNOWLEDGE entity does not conflict with anything.
1140*993b0882SAndroid Build Coastguard Worker       if ((source_mask &
1141*993b0882SAndroid Build Coastguard Worker            (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1142*993b0882SAndroid Build Coastguard Worker         return false;
1143*993b0882SAndroid Build Coastguard Worker       }
1144*993b0882SAndroid Build Coastguard Worker 
1145*993b0882SAndroid Build Coastguard Worker       // A PERSONNAME entity does not conflict with anything.
1146*993b0882SAndroid Build Coastguard Worker       if ((source_mask &
1147*993b0882SAndroid Build Coastguard Worker            (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1148*993b0882SAndroid Build Coastguard Worker         return false;
1149*993b0882SAndroid Build Coastguard Worker       }
1150*993b0882SAndroid Build Coastguard Worker 
1151*993b0882SAndroid Build Coastguard Worker       // Entities from other sources can conflict.
1152*993b0882SAndroid Build Coastguard Worker       return true;
1153*993b0882SAndroid Build Coastguard Worker   }
1154*993b0882SAndroid Build Coastguard Worker }
1155*993b0882SAndroid Build Coastguard Worker }  // namespace
1156*993b0882SAndroid Build Coastguard Worker 
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1157*993b0882SAndroid Build Coastguard Worker bool Annotator::ResolveConflict(
1158*993b0882SAndroid Build Coastguard Worker     const std::string& context, const std::vector<Token>& cached_tokens,
1159*993b0882SAndroid Build Coastguard Worker     const std::vector<AnnotatedSpan>& candidates,
1160*993b0882SAndroid Build Coastguard Worker     const std::vector<Locale>& detected_text_language_tags, int start_index,
1161*993b0882SAndroid Build Coastguard Worker     int end_index, const BaseOptions& options,
1162*993b0882SAndroid Build Coastguard Worker     InterpreterManager* interpreter_manager,
1163*993b0882SAndroid Build Coastguard Worker     std::vector<int>* chosen_indices) const {
1164*993b0882SAndroid Build Coastguard Worker   std::vector<int> conflicting_indices;
1165*993b0882SAndroid Build Coastguard Worker   std::unordered_map<int, std::pair<float, int>> scores_lengths;
1166*993b0882SAndroid Build Coastguard Worker   for (int i = start_index; i < end_index; ++i) {
1167*993b0882SAndroid Build Coastguard Worker     conflicting_indices.push_back(i);
1168*993b0882SAndroid Build Coastguard Worker     if (!candidates[i].classification.empty()) {
1169*993b0882SAndroid Build Coastguard Worker       scores_lengths[i] = {
1170*993b0882SAndroid Build Coastguard Worker           GetPriorityScore(candidates[i].classification),
1171*993b0882SAndroid Build Coastguard Worker           candidates[i].span.second - candidates[i].span.first};
1172*993b0882SAndroid Build Coastguard Worker       continue;
1173*993b0882SAndroid Build Coastguard Worker     }
1174*993b0882SAndroid Build Coastguard Worker 
1175*993b0882SAndroid Build Coastguard Worker     // OPTIMIZATION: So that we don't have to classify all the ML model
1176*993b0882SAndroid Build Coastguard Worker     // spans apriori, we wait until we get here, when they conflict with
1177*993b0882SAndroid Build Coastguard Worker     // something and we need the actual classification scores. So if the
1178*993b0882SAndroid Build Coastguard Worker     // candidate conflicts and comes from the model, we need to run a
1179*993b0882SAndroid Build Coastguard Worker     // classification to determine its priority:
1180*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult> classification;
1181*993b0882SAndroid Build Coastguard Worker     if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1182*993b0882SAndroid Build Coastguard Worker                            candidates[i].span, options, interpreter_manager,
1183*993b0882SAndroid Build Coastguard Worker                            /*embedding_cache=*/nullptr, &classification,
1184*993b0882SAndroid Build Coastguard Worker                            /*tokens=*/nullptr)) {
1185*993b0882SAndroid Build Coastguard Worker       return false;
1186*993b0882SAndroid Build Coastguard Worker     }
1187*993b0882SAndroid Build Coastguard Worker 
1188*993b0882SAndroid Build Coastguard Worker     if (!classification.empty()) {
1189*993b0882SAndroid Build Coastguard Worker       scores_lengths[i] = {
1190*993b0882SAndroid Build Coastguard Worker           GetPriorityScore(classification),
1191*993b0882SAndroid Build Coastguard Worker           candidates[i].span.second - candidates[i].span.first};
1192*993b0882SAndroid Build Coastguard Worker     }
1193*993b0882SAndroid Build Coastguard Worker   }
1194*993b0882SAndroid Build Coastguard Worker 
1195*993b0882SAndroid Build Coastguard Worker   std::stable_sort(
1196*993b0882SAndroid Build Coastguard Worker       conflicting_indices.begin(), conflicting_indices.end(),
1197*993b0882SAndroid Build Coastguard Worker       [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1198*993b0882SAndroid Build Coastguard Worker         if (scores_lengths[i].first == scores_lengths[j].first &&
1199*993b0882SAndroid Build Coastguard Worker             prioritize_longest_annotation_) {
1200*993b0882SAndroid Build Coastguard Worker           return scores_lengths[i].second > scores_lengths[j].second;
1201*993b0882SAndroid Build Coastguard Worker         }
1202*993b0882SAndroid Build Coastguard Worker         return scores_lengths[i].first > scores_lengths[j].first;
1203*993b0882SAndroid Build Coastguard Worker       });
1204*993b0882SAndroid Build Coastguard Worker 
1205*993b0882SAndroid Build Coastguard Worker   // Here we keep a set of indices that were chosen, per-source, to enable
1206*993b0882SAndroid Build Coastguard Worker   // effective computation.
1207*993b0882SAndroid Build Coastguard Worker   std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1208*993b0882SAndroid Build Coastguard Worker       chosen_indices_for_source_map;
1209*993b0882SAndroid Build Coastguard Worker 
1210*993b0882SAndroid Build Coastguard Worker   // Greedily place the candidates if they don't conflict with the already
1211*993b0882SAndroid Build Coastguard Worker   // placed ones.
1212*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < conflicting_indices.size(); ++i) {
1213*993b0882SAndroid Build Coastguard Worker     const int considered_candidate = conflicting_indices[i];
1214*993b0882SAndroid Build Coastguard Worker 
1215*993b0882SAndroid Build Coastguard Worker     // See if there is a conflict between the candidate and all already placed
1216*993b0882SAndroid Build Coastguard Worker     // candidates.
1217*993b0882SAndroid Build Coastguard Worker     bool conflict = false;
1218*993b0882SAndroid Build Coastguard Worker     SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1219*993b0882SAndroid Build Coastguard Worker     for (auto& source_set_pair : chosen_indices_for_source_map) {
1220*993b0882SAndroid Build Coastguard Worker       if (source_set_pair.first == candidates[considered_candidate].source) {
1221*993b0882SAndroid Build Coastguard Worker         chosen_indices_for_source_ptr = &source_set_pair.second;
1222*993b0882SAndroid Build Coastguard Worker       }
1223*993b0882SAndroid Build Coastguard Worker 
1224*993b0882SAndroid Build Coastguard Worker       const bool needs_conflict_resolution =
1225*993b0882SAndroid Build Coastguard Worker           options.annotation_usecase ==
1226*993b0882SAndroid Build Coastguard Worker               AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1227*993b0882SAndroid Build Coastguard Worker           (options.annotation_usecase ==
1228*993b0882SAndroid Build Coastguard Worker                AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1229*993b0882SAndroid Build Coastguard Worker            do_conflict_resolution_in_raw_mode_);
1230*993b0882SAndroid Build Coastguard Worker       if (needs_conflict_resolution &&
1231*993b0882SAndroid Build Coastguard Worker           DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1232*993b0882SAndroid Build Coastguard Worker                             candidates[considered_candidate].source) &&
1233*993b0882SAndroid Build Coastguard Worker           DoesCandidateConflict(considered_candidate, candidates,
1234*993b0882SAndroid Build Coastguard Worker                                 source_set_pair.second)) {
1235*993b0882SAndroid Build Coastguard Worker         conflict = true;
1236*993b0882SAndroid Build Coastguard Worker         break;
1237*993b0882SAndroid Build Coastguard Worker       }
1238*993b0882SAndroid Build Coastguard Worker     }
1239*993b0882SAndroid Build Coastguard Worker 
1240*993b0882SAndroid Build Coastguard Worker     // Skip the candidate if a conflict was found.
1241*993b0882SAndroid Build Coastguard Worker     if (conflict) {
1242*993b0882SAndroid Build Coastguard Worker       continue;
1243*993b0882SAndroid Build Coastguard Worker     }
1244*993b0882SAndroid Build Coastguard Worker 
1245*993b0882SAndroid Build Coastguard Worker     // If the set of indices for the current source doesn't exist yet,
1246*993b0882SAndroid Build Coastguard Worker     // initialize it.
1247*993b0882SAndroid Build Coastguard Worker     if (chosen_indices_for_source_ptr == nullptr) {
1248*993b0882SAndroid Build Coastguard Worker       SortedIntSet new_set([&candidates](int a, int b) {
1249*993b0882SAndroid Build Coastguard Worker         return candidates[a].span.first < candidates[b].span.first;
1250*993b0882SAndroid Build Coastguard Worker       });
1251*993b0882SAndroid Build Coastguard Worker       chosen_indices_for_source_map[candidates[considered_candidate].source] =
1252*993b0882SAndroid Build Coastguard Worker           std::move(new_set);
1253*993b0882SAndroid Build Coastguard Worker       chosen_indices_for_source_ptr =
1254*993b0882SAndroid Build Coastguard Worker           &chosen_indices_for_source_map[candidates[considered_candidate]
1255*993b0882SAndroid Build Coastguard Worker                                              .source];
1256*993b0882SAndroid Build Coastguard Worker     }
1257*993b0882SAndroid Build Coastguard Worker 
1258*993b0882SAndroid Build Coastguard Worker     // Place the candidate to the output and to the per-source conflict set.
1259*993b0882SAndroid Build Coastguard Worker     chosen_indices->push_back(considered_candidate);
1260*993b0882SAndroid Build Coastguard Worker     chosen_indices_for_source_ptr->insert(considered_candidate);
1261*993b0882SAndroid Build Coastguard Worker   }
1262*993b0882SAndroid Build Coastguard Worker 
1263*993b0882SAndroid Build Coastguard Worker   std::stable_sort(chosen_indices->begin(), chosen_indices->end());
1264*993b0882SAndroid Build Coastguard Worker 
1265*993b0882SAndroid Build Coastguard Worker   return true;
1266*993b0882SAndroid Build Coastguard Worker }
1267*993b0882SAndroid Build Coastguard Worker 
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1268*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelSuggestSelection(
1269*993b0882SAndroid Build Coastguard Worker     const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1270*993b0882SAndroid Build Coastguard Worker     const std::vector<Locale>& detected_text_language_tags,
1271*993b0882SAndroid Build Coastguard Worker     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1272*993b0882SAndroid Build Coastguard Worker     std::vector<AnnotatedSpan>* result) const {
1273*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() == nullptr ||
1274*993b0882SAndroid Build Coastguard Worker       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1275*993b0882SAndroid Build Coastguard Worker     return true;
1276*993b0882SAndroid Build Coastguard Worker   }
1277*993b0882SAndroid Build Coastguard Worker 
1278*993b0882SAndroid Build Coastguard Worker   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1279*993b0882SAndroid Build Coastguard Worker                                     ml_model_triggering_locales_,
1280*993b0882SAndroid Build Coastguard Worker                                     /*default_value=*/true)) {
1281*993b0882SAndroid Build Coastguard Worker     return true;
1282*993b0882SAndroid Build Coastguard Worker   }
1283*993b0882SAndroid Build Coastguard Worker 
1284*993b0882SAndroid Build Coastguard Worker   int click_pos;
1285*993b0882SAndroid Build Coastguard Worker   *tokens = selection_feature_processor_->Tokenize(context_unicode);
1286*993b0882SAndroid Build Coastguard Worker   const auto [click_begin, click_end] =
1287*993b0882SAndroid Build Coastguard Worker       CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1288*993b0882SAndroid Build Coastguard Worker   selection_feature_processor_->RetokenizeAndFindClick(
1289*993b0882SAndroid Build Coastguard Worker       context_unicode, click_begin, click_end, click_indices,
1290*993b0882SAndroid Build Coastguard Worker       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1291*993b0882SAndroid Build Coastguard Worker       tokens, &click_pos);
1292*993b0882SAndroid Build Coastguard Worker   if (click_pos == kInvalidIndex) {
1293*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1) << "Could not calculate the click position.";
1294*993b0882SAndroid Build Coastguard Worker     return false;
1295*993b0882SAndroid Build Coastguard Worker   }
1296*993b0882SAndroid Build Coastguard Worker 
1297*993b0882SAndroid Build Coastguard Worker   const int symmetry_context_size =
1298*993b0882SAndroid Build Coastguard Worker       model_->selection_options()->symmetry_context_size();
1299*993b0882SAndroid Build Coastguard Worker   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1300*993b0882SAndroid Build Coastguard Worker       bounds_sensitive_features = selection_feature_processor_->GetOptions()
1301*993b0882SAndroid Build Coastguard Worker                                       ->bounds_sensitive_features();
1302*993b0882SAndroid Build Coastguard Worker 
1303*993b0882SAndroid Build Coastguard Worker   // The symmetry context span is the clicked token with symmetry_context_size
1304*993b0882SAndroid Build Coastguard Worker   // tokens on either side.
1305*993b0882SAndroid Build Coastguard Worker   const TokenSpan symmetry_context_span =
1306*993b0882SAndroid Build Coastguard Worker       IntersectTokenSpans(TokenSpan(click_pos).Expand(
1307*993b0882SAndroid Build Coastguard Worker                               /*num_tokens_left=*/symmetry_context_size,
1308*993b0882SAndroid Build Coastguard Worker                               /*num_tokens_right=*/symmetry_context_size),
1309*993b0882SAndroid Build Coastguard Worker                           AllOf(*tokens));
1310*993b0882SAndroid Build Coastguard Worker 
1311*993b0882SAndroid Build Coastguard Worker   // Compute the extraction span based on the model type.
1312*993b0882SAndroid Build Coastguard Worker   TokenSpan extraction_span;
1313*993b0882SAndroid Build Coastguard Worker   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1314*993b0882SAndroid Build Coastguard Worker     // The extraction span is the symmetry context span expanded to include
1315*993b0882SAndroid Build Coastguard Worker     // max_selection_span tokens on either side, which is how far a selection
1316*993b0882SAndroid Build Coastguard Worker     // can stretch from the click, plus a relevant number of tokens outside of
1317*993b0882SAndroid Build Coastguard Worker     // the bounds of the selection.
1318*993b0882SAndroid Build Coastguard Worker     const int max_selection_span =
1319*993b0882SAndroid Build Coastguard Worker         selection_feature_processor_->GetOptions()->max_selection_span();
1320*993b0882SAndroid Build Coastguard Worker     extraction_span = symmetry_context_span.Expand(
1321*993b0882SAndroid Build Coastguard Worker         /*num_tokens_left=*/max_selection_span +
1322*993b0882SAndroid Build Coastguard Worker             bounds_sensitive_features->num_tokens_before(),
1323*993b0882SAndroid Build Coastguard Worker         /*num_tokens_right=*/max_selection_span +
1324*993b0882SAndroid Build Coastguard Worker             bounds_sensitive_features->num_tokens_after());
1325*993b0882SAndroid Build Coastguard Worker   } else {
1326*993b0882SAndroid Build Coastguard Worker     // The extraction span is the symmetry context span expanded to include
1327*993b0882SAndroid Build Coastguard Worker     // context_size tokens on either side.
1328*993b0882SAndroid Build Coastguard Worker     const int context_size =
1329*993b0882SAndroid Build Coastguard Worker         selection_feature_processor_->GetOptions()->context_size();
1330*993b0882SAndroid Build Coastguard Worker     extraction_span = symmetry_context_span.Expand(
1331*993b0882SAndroid Build Coastguard Worker         /*num_tokens_left=*/context_size,
1332*993b0882SAndroid Build Coastguard Worker         /*num_tokens_right=*/context_size);
1333*993b0882SAndroid Build Coastguard Worker   }
1334*993b0882SAndroid Build Coastguard Worker   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1335*993b0882SAndroid Build Coastguard Worker 
1336*993b0882SAndroid Build Coastguard Worker   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1337*993b0882SAndroid Build Coastguard Worker           *tokens, extraction_span)) {
1338*993b0882SAndroid Build Coastguard Worker     return true;
1339*993b0882SAndroid Build Coastguard Worker   }
1340*993b0882SAndroid Build Coastguard Worker 
1341*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<CachedFeatures> cached_features;
1342*993b0882SAndroid Build Coastguard Worker   if (!selection_feature_processor_->ExtractFeatures(
1343*993b0882SAndroid Build Coastguard Worker           *tokens, extraction_span,
1344*993b0882SAndroid Build Coastguard Worker           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1345*993b0882SAndroid Build Coastguard Worker           embedding_executor_.get(),
1346*993b0882SAndroid Build Coastguard Worker           /*embedding_cache=*/nullptr,
1347*993b0882SAndroid Build Coastguard Worker           selection_feature_processor_->EmbeddingSize() +
1348*993b0882SAndroid Build Coastguard Worker               selection_feature_processor_->DenseFeaturesCount(),
1349*993b0882SAndroid Build Coastguard Worker           &cached_features)) {
1350*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not extract features.";
1351*993b0882SAndroid Build Coastguard Worker     return false;
1352*993b0882SAndroid Build Coastguard Worker   }
1353*993b0882SAndroid Build Coastguard Worker 
1354*993b0882SAndroid Build Coastguard Worker   // Produce selection model candidates.
1355*993b0882SAndroid Build Coastguard Worker   std::vector<TokenSpan> chunks;
1356*993b0882SAndroid Build Coastguard Worker   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1357*993b0882SAndroid Build Coastguard Worker                   interpreter_manager->SelectionInterpreter(), *cached_features,
1358*993b0882SAndroid Build Coastguard Worker                   &chunks)) {
1359*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not chunk.";
1360*993b0882SAndroid Build Coastguard Worker     return false;
1361*993b0882SAndroid Build Coastguard Worker   }
1362*993b0882SAndroid Build Coastguard Worker 
1363*993b0882SAndroid Build Coastguard Worker   for (const TokenSpan& chunk : chunks) {
1364*993b0882SAndroid Build Coastguard Worker     AnnotatedSpan candidate;
1365*993b0882SAndroid Build Coastguard Worker     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1366*993b0882SAndroid Build Coastguard Worker         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1367*993b0882SAndroid Build Coastguard Worker     if (model_->selection_options()->strip_unpaired_brackets()) {
1368*993b0882SAndroid Build Coastguard Worker       candidate.span =
1369*993b0882SAndroid Build Coastguard Worker           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1370*993b0882SAndroid Build Coastguard Worker     }
1371*993b0882SAndroid Build Coastguard Worker 
1372*993b0882SAndroid Build Coastguard Worker     // Only output non-empty spans.
1373*993b0882SAndroid Build Coastguard Worker     if (candidate.span.first != candidate.span.second) {
1374*993b0882SAndroid Build Coastguard Worker       result->push_back(candidate);
1375*993b0882SAndroid Build Coastguard Worker     }
1376*993b0882SAndroid Build Coastguard Worker   }
1377*993b0882SAndroid Build Coastguard Worker   return true;
1378*993b0882SAndroid Build Coastguard Worker }
1379*993b0882SAndroid Build Coastguard Worker 
1380*993b0882SAndroid Build Coastguard Worker namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1381*993b0882SAndroid Build Coastguard Worker std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1382*993b0882SAndroid Build Coastguard Worker                                     const CodepointSpan& selection_indices,
1383*993b0882SAndroid Build Coastguard Worker                                     TokenSpan tokens_around_selection_to_copy) {
1384*993b0882SAndroid Build Coastguard Worker   const auto first_selection_token = std::upper_bound(
1385*993b0882SAndroid Build Coastguard Worker       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1386*993b0882SAndroid Build Coastguard Worker       [](int selection_start, const Token& token) {
1387*993b0882SAndroid Build Coastguard Worker         return selection_start < token.end;
1388*993b0882SAndroid Build Coastguard Worker       });
1389*993b0882SAndroid Build Coastguard Worker   const auto last_selection_token = std::lower_bound(
1390*993b0882SAndroid Build Coastguard Worker       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1391*993b0882SAndroid Build Coastguard Worker       [](const Token& token, int selection_end) {
1392*993b0882SAndroid Build Coastguard Worker         return token.start < selection_end;
1393*993b0882SAndroid Build Coastguard Worker       });
1394*993b0882SAndroid Build Coastguard Worker 
1395*993b0882SAndroid Build Coastguard Worker   const int64 first_token = std::max(
1396*993b0882SAndroid Build Coastguard Worker       static_cast<int64>(0),
1397*993b0882SAndroid Build Coastguard Worker       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1398*993b0882SAndroid Build Coastguard Worker                          tokens_around_selection_to_copy.first));
1399*993b0882SAndroid Build Coastguard Worker   const int64 last_token = std::min(
1400*993b0882SAndroid Build Coastguard Worker       static_cast<int64>(cached_tokens.size()),
1401*993b0882SAndroid Build Coastguard Worker       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1402*993b0882SAndroid Build Coastguard Worker                          tokens_around_selection_to_copy.second));
1403*993b0882SAndroid Build Coastguard Worker 
1404*993b0882SAndroid Build Coastguard Worker   std::vector<Token> tokens;
1405*993b0882SAndroid Build Coastguard Worker   tokens.reserve(last_token - first_token);
1406*993b0882SAndroid Build Coastguard Worker   for (int i = first_token; i < last_token; ++i) {
1407*993b0882SAndroid Build Coastguard Worker     tokens.push_back(cached_tokens[i]);
1408*993b0882SAndroid Build Coastguard Worker   }
1409*993b0882SAndroid Build Coastguard Worker   return tokens;
1410*993b0882SAndroid Build Coastguard Worker }
1411*993b0882SAndroid Build Coastguard Worker }  // namespace internal
1412*993b0882SAndroid Build Coastguard Worker 
ClassifyTextUpperBoundNeededTokens() const1413*993b0882SAndroid Build Coastguard Worker TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1414*993b0882SAndroid Build Coastguard Worker   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1415*993b0882SAndroid Build Coastguard Worker       bounds_sensitive_features =
1416*993b0882SAndroid Build Coastguard Worker           classification_feature_processor_->GetOptions()
1417*993b0882SAndroid Build Coastguard Worker               ->bounds_sensitive_features();
1418*993b0882SAndroid Build Coastguard Worker   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1419*993b0882SAndroid Build Coastguard Worker     // The extraction span is the selection span expanded to include a relevant
1420*993b0882SAndroid Build Coastguard Worker     // number of tokens outside of the bounds of the selection.
1421*993b0882SAndroid Build Coastguard Worker     return {bounds_sensitive_features->num_tokens_before(),
1422*993b0882SAndroid Build Coastguard Worker             bounds_sensitive_features->num_tokens_after()};
1423*993b0882SAndroid Build Coastguard Worker   } else {
1424*993b0882SAndroid Build Coastguard Worker     // The extraction span is the clicked token with context_size tokens on
1425*993b0882SAndroid Build Coastguard Worker     // either side.
1426*993b0882SAndroid Build Coastguard Worker     const int context_size =
1427*993b0882SAndroid Build Coastguard Worker         selection_feature_processor_->GetOptions()->context_size();
1428*993b0882SAndroid Build Coastguard Worker     return {context_size, context_size};
1429*993b0882SAndroid Build Coastguard Worker   }
1430*993b0882SAndroid Build Coastguard Worker }
1431*993b0882SAndroid Build Coastguard Worker 
1432*993b0882SAndroid Build Coastguard Worker namespace {
1433*993b0882SAndroid Build Coastguard Worker // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1434*993b0882SAndroid Build Coastguard Worker void SortClassificationResults(
1435*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult>* classification_results) {
1436*993b0882SAndroid Build Coastguard Worker   std::stable_sort(
1437*993b0882SAndroid Build Coastguard Worker       classification_results->begin(), classification_results->end(),
1438*993b0882SAndroid Build Coastguard Worker       [](const ClassificationResult& a, const ClassificationResult& b) {
1439*993b0882SAndroid Build Coastguard Worker         return a.score > b.score;
1440*993b0882SAndroid Build Coastguard Worker       });
1441*993b0882SAndroid Build Coastguard Worker }
1442*993b0882SAndroid Build Coastguard Worker }  // namespace
1443*993b0882SAndroid Build Coastguard Worker 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1444*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelClassifyText(
1445*993b0882SAndroid Build Coastguard Worker     const std::string& context, const std::vector<Token>& cached_tokens,
1446*993b0882SAndroid Build Coastguard Worker     const std::vector<Locale>& detected_text_language_tags,
1447*993b0882SAndroid Build Coastguard Worker     const CodepointSpan& selection_indices, const BaseOptions& options,
1448*993b0882SAndroid Build Coastguard Worker     InterpreterManager* interpreter_manager,
1449*993b0882SAndroid Build Coastguard Worker     FeatureProcessor::EmbeddingCache* embedding_cache,
1450*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult>* classification_results,
1451*993b0882SAndroid Build Coastguard Worker     std::vector<Token>* tokens) const {
1452*993b0882SAndroid Build Coastguard Worker   const UnicodeText context_unicode =
1453*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(context, /*do_copy=*/false);
1454*993b0882SAndroid Build Coastguard Worker   const auto [span_begin, span_end] =
1455*993b0882SAndroid Build Coastguard Worker       CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1456*993b0882SAndroid Build Coastguard Worker   return ModelClassifyText(context_unicode, cached_tokens,
1457*993b0882SAndroid Build Coastguard Worker                            detected_text_language_tags, span_begin, span_end,
1458*993b0882SAndroid Build Coastguard Worker                            /*line=*/nullptr, selection_indices, options,
1459*993b0882SAndroid Build Coastguard Worker                            interpreter_manager, embedding_cache,
1460*993b0882SAndroid Build Coastguard Worker                            classification_results, tokens);
1461*993b0882SAndroid Build Coastguard Worker }
1462*993b0882SAndroid Build Coastguard Worker 
ModelClassifyText(const UnicodeText & context_unicode,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const UnicodeTextRange * line,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1463*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelClassifyText(
1464*993b0882SAndroid Build Coastguard Worker     const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1465*993b0882SAndroid Build Coastguard Worker     const std::vector<Locale>& detected_text_language_tags,
1466*993b0882SAndroid Build Coastguard Worker     const UnicodeText::const_iterator& span_begin,
1467*993b0882SAndroid Build Coastguard Worker     const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1468*993b0882SAndroid Build Coastguard Worker     const CodepointSpan& selection_indices, const BaseOptions& options,
1469*993b0882SAndroid Build Coastguard Worker     InterpreterManager* interpreter_manager,
1470*993b0882SAndroid Build Coastguard Worker     FeatureProcessor::EmbeddingCache* embedding_cache,
1471*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult>* classification_results,
1472*993b0882SAndroid Build Coastguard Worker     std::vector<Token>* tokens) const {
1473*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() == nullptr ||
1474*993b0882SAndroid Build Coastguard Worker       !(model_->triggering_options()->enabled_modes() &
1475*993b0882SAndroid Build Coastguard Worker         ModeFlag_CLASSIFICATION)) {
1476*993b0882SAndroid Build Coastguard Worker     return true;
1477*993b0882SAndroid Build Coastguard Worker   }
1478*993b0882SAndroid Build Coastguard Worker 
1479*993b0882SAndroid Build Coastguard Worker   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1480*993b0882SAndroid Build Coastguard Worker                                     ml_model_triggering_locales_,
1481*993b0882SAndroid Build Coastguard Worker                                     /*default_value=*/true)) {
1482*993b0882SAndroid Build Coastguard Worker     return true;
1483*993b0882SAndroid Build Coastguard Worker   }
1484*993b0882SAndroid Build Coastguard Worker 
1485*993b0882SAndroid Build Coastguard Worker   std::vector<Token> local_tokens;
1486*993b0882SAndroid Build Coastguard Worker   if (tokens == nullptr) {
1487*993b0882SAndroid Build Coastguard Worker     tokens = &local_tokens;
1488*993b0882SAndroid Build Coastguard Worker   }
1489*993b0882SAndroid Build Coastguard Worker 
1490*993b0882SAndroid Build Coastguard Worker   if (cached_tokens.empty()) {
1491*993b0882SAndroid Build Coastguard Worker     *tokens = classification_feature_processor_->Tokenize(context_unicode);
1492*993b0882SAndroid Build Coastguard Worker   } else {
1493*993b0882SAndroid Build Coastguard Worker     *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1494*993b0882SAndroid Build Coastguard Worker                                          ClassifyTextUpperBoundNeededTokens());
1495*993b0882SAndroid Build Coastguard Worker   }
1496*993b0882SAndroid Build Coastguard Worker 
1497*993b0882SAndroid Build Coastguard Worker   int click_pos;
1498*993b0882SAndroid Build Coastguard Worker   classification_feature_processor_->RetokenizeAndFindClick(
1499*993b0882SAndroid Build Coastguard Worker       context_unicode, span_begin, span_end, selection_indices,
1500*993b0882SAndroid Build Coastguard Worker       classification_feature_processor_->GetOptions()
1501*993b0882SAndroid Build Coastguard Worker           ->only_use_line_with_click(),
1502*993b0882SAndroid Build Coastguard Worker       tokens, &click_pos);
1503*993b0882SAndroid Build Coastguard Worker   const TokenSpan selection_token_span =
1504*993b0882SAndroid Build Coastguard Worker       CodepointSpanToTokenSpan(*tokens, selection_indices);
1505*993b0882SAndroid Build Coastguard Worker   const int selection_num_tokens = selection_token_span.Size();
1506*993b0882SAndroid Build Coastguard Worker   if (model_->classification_options()->max_num_tokens() > 0 &&
1507*993b0882SAndroid Build Coastguard Worker       model_->classification_options()->max_num_tokens() <
1508*993b0882SAndroid Build Coastguard Worker           selection_num_tokens) {
1509*993b0882SAndroid Build Coastguard Worker     *classification_results = {{Collections::Other(), 1.0}};
1510*993b0882SAndroid Build Coastguard Worker     return true;
1511*993b0882SAndroid Build Coastguard Worker   }
1512*993b0882SAndroid Build Coastguard Worker 
1513*993b0882SAndroid Build Coastguard Worker   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1514*993b0882SAndroid Build Coastguard Worker       bounds_sensitive_features =
1515*993b0882SAndroid Build Coastguard Worker           classification_feature_processor_->GetOptions()
1516*993b0882SAndroid Build Coastguard Worker               ->bounds_sensitive_features();
1517*993b0882SAndroid Build Coastguard Worker   if (selection_token_span.first == kInvalidIndex ||
1518*993b0882SAndroid Build Coastguard Worker       selection_token_span.second == kInvalidIndex) {
1519*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not determine span.";
1520*993b0882SAndroid Build Coastguard Worker     return false;
1521*993b0882SAndroid Build Coastguard Worker   }
1522*993b0882SAndroid Build Coastguard Worker 
1523*993b0882SAndroid Build Coastguard Worker   // Compute the extraction span based on the model type.
1524*993b0882SAndroid Build Coastguard Worker   TokenSpan extraction_span;
1525*993b0882SAndroid Build Coastguard Worker   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1526*993b0882SAndroid Build Coastguard Worker     // The extraction span is the selection span expanded to include a relevant
1527*993b0882SAndroid Build Coastguard Worker     // number of tokens outside of the bounds of the selection.
1528*993b0882SAndroid Build Coastguard Worker     extraction_span = selection_token_span.Expand(
1529*993b0882SAndroid Build Coastguard Worker         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1530*993b0882SAndroid Build Coastguard Worker         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1531*993b0882SAndroid Build Coastguard Worker   } else {
1532*993b0882SAndroid Build Coastguard Worker     if (click_pos == kInvalidIndex) {
1533*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Couldn't choose a click position.";
1534*993b0882SAndroid Build Coastguard Worker       return false;
1535*993b0882SAndroid Build Coastguard Worker     }
1536*993b0882SAndroid Build Coastguard Worker     // The extraction span is the clicked token with context_size tokens on
1537*993b0882SAndroid Build Coastguard Worker     // either side.
1538*993b0882SAndroid Build Coastguard Worker     const int context_size =
1539*993b0882SAndroid Build Coastguard Worker         classification_feature_processor_->GetOptions()->context_size();
1540*993b0882SAndroid Build Coastguard Worker     extraction_span = TokenSpan(click_pos).Expand(
1541*993b0882SAndroid Build Coastguard Worker         /*num_tokens_left=*/context_size,
1542*993b0882SAndroid Build Coastguard Worker         /*num_tokens_right=*/context_size);
1543*993b0882SAndroid Build Coastguard Worker   }
1544*993b0882SAndroid Build Coastguard Worker   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1545*993b0882SAndroid Build Coastguard Worker 
1546*993b0882SAndroid Build Coastguard Worker   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1547*993b0882SAndroid Build Coastguard Worker           *tokens, extraction_span)) {
1548*993b0882SAndroid Build Coastguard Worker     *classification_results = {{Collections::Other(), 1.0}};
1549*993b0882SAndroid Build Coastguard Worker     return true;
1550*993b0882SAndroid Build Coastguard Worker   }
1551*993b0882SAndroid Build Coastguard Worker 
1552*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<CachedFeatures> cached_features;
1553*993b0882SAndroid Build Coastguard Worker   if (!classification_feature_processor_->ExtractFeatures(
1554*993b0882SAndroid Build Coastguard Worker           *tokens, extraction_span, selection_indices,
1555*993b0882SAndroid Build Coastguard Worker           embedding_executor_.get(), embedding_cache,
1556*993b0882SAndroid Build Coastguard Worker           classification_feature_processor_->EmbeddingSize() +
1557*993b0882SAndroid Build Coastguard Worker               classification_feature_processor_->DenseFeaturesCount(),
1558*993b0882SAndroid Build Coastguard Worker           &cached_features)) {
1559*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not extract features.";
1560*993b0882SAndroid Build Coastguard Worker     return false;
1561*993b0882SAndroid Build Coastguard Worker   }
1562*993b0882SAndroid Build Coastguard Worker 
1563*993b0882SAndroid Build Coastguard Worker   std::vector<float> features;
1564*993b0882SAndroid Build Coastguard Worker   features.reserve(cached_features->OutputFeaturesSize());
1565*993b0882SAndroid Build Coastguard Worker   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1566*993b0882SAndroid Build Coastguard Worker     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1567*993b0882SAndroid Build Coastguard Worker                                                           &features);
1568*993b0882SAndroid Build Coastguard Worker   } else {
1569*993b0882SAndroid Build Coastguard Worker     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1570*993b0882SAndroid Build Coastguard Worker   }
1571*993b0882SAndroid Build Coastguard Worker 
1572*993b0882SAndroid Build Coastguard Worker   TensorView<float> logits = classification_executor_->ComputeLogits(
1573*993b0882SAndroid Build Coastguard Worker       TensorView<float>(features.data(),
1574*993b0882SAndroid Build Coastguard Worker                         {1, static_cast<int>(features.size())}),
1575*993b0882SAndroid Build Coastguard Worker       interpreter_manager->ClassificationInterpreter());
1576*993b0882SAndroid Build Coastguard Worker   if (!logits.is_valid()) {
1577*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Couldn't compute logits.";
1578*993b0882SAndroid Build Coastguard Worker     return false;
1579*993b0882SAndroid Build Coastguard Worker   }
1580*993b0882SAndroid Build Coastguard Worker 
1581*993b0882SAndroid Build Coastguard Worker   if (logits.dims() != 2 || logits.dim(0) != 1 ||
1582*993b0882SAndroid Build Coastguard Worker       logits.dim(1) != classification_feature_processor_->NumCollections()) {
1583*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Mismatching output";
1584*993b0882SAndroid Build Coastguard Worker     return false;
1585*993b0882SAndroid Build Coastguard Worker   }
1586*993b0882SAndroid Build Coastguard Worker 
1587*993b0882SAndroid Build Coastguard Worker   const std::vector<float> scores =
1588*993b0882SAndroid Build Coastguard Worker       ComputeSoftmax(logits.data(), logits.dim(1));
1589*993b0882SAndroid Build Coastguard Worker 
1590*993b0882SAndroid Build Coastguard Worker   if (scores.empty()) {
1591*993b0882SAndroid Build Coastguard Worker     *classification_results = {{Collections::Other(), 1.0}};
1592*993b0882SAndroid Build Coastguard Worker     return true;
1593*993b0882SAndroid Build Coastguard Worker   }
1594*993b0882SAndroid Build Coastguard Worker 
1595*993b0882SAndroid Build Coastguard Worker   const int best_score_index =
1596*993b0882SAndroid Build Coastguard Worker       std::max_element(scores.begin(), scores.end()) - scores.begin();
1597*993b0882SAndroid Build Coastguard Worker   const std::string top_collection =
1598*993b0882SAndroid Build Coastguard Worker       classification_feature_processor_->LabelToCollection(best_score_index);
1599*993b0882SAndroid Build Coastguard Worker 
1600*993b0882SAndroid Build Coastguard Worker   // Sanity checks.
1601*993b0882SAndroid Build Coastguard Worker   if (top_collection == Collections::Phone()) {
1602*993b0882SAndroid Build Coastguard Worker     const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1603*993b0882SAndroid Build Coastguard Worker     if (digit_count <
1604*993b0882SAndroid Build Coastguard Worker             model_->classification_options()->phone_min_num_digits() ||
1605*993b0882SAndroid Build Coastguard Worker         digit_count >
1606*993b0882SAndroid Build Coastguard Worker             model_->classification_options()->phone_max_num_digits()) {
1607*993b0882SAndroid Build Coastguard Worker       *classification_results = {{Collections::Other(), 1.0}};
1608*993b0882SAndroid Build Coastguard Worker       return true;
1609*993b0882SAndroid Build Coastguard Worker     }
1610*993b0882SAndroid Build Coastguard Worker   } else if (top_collection == Collections::Address()) {
1611*993b0882SAndroid Build Coastguard Worker     if (selection_num_tokens <
1612*993b0882SAndroid Build Coastguard Worker         model_->classification_options()->address_min_num_tokens()) {
1613*993b0882SAndroid Build Coastguard Worker       *classification_results = {{Collections::Other(), 1.0}};
1614*993b0882SAndroid Build Coastguard Worker       return true;
1615*993b0882SAndroid Build Coastguard Worker     }
1616*993b0882SAndroid Build Coastguard Worker   } else if (top_collection == Collections::Dictionary()) {
1617*993b0882SAndroid Build Coastguard Worker     if ((options.use_vocab_annotator && vocab_annotator_) ||
1618*993b0882SAndroid Build Coastguard Worker         !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1619*993b0882SAndroid Build Coastguard Worker                                       dictionary_locales_,
1620*993b0882SAndroid Build Coastguard Worker                                       /*default_value=*/false)) {
1621*993b0882SAndroid Build Coastguard Worker       *classification_results = {{Collections::Other(), 1.0}};
1622*993b0882SAndroid Build Coastguard Worker       return true;
1623*993b0882SAndroid Build Coastguard Worker     }
1624*993b0882SAndroid Build Coastguard Worker   }
1625*993b0882SAndroid Build Coastguard Worker   *classification_results = {{top_collection, /*arg_score=*/1.0,
1626*993b0882SAndroid Build Coastguard Worker                               /*arg_priority_score=*/scores[best_score_index]}};
1627*993b0882SAndroid Build Coastguard Worker 
1628*993b0882SAndroid Build Coastguard Worker   // For some entities, we might want to clamp the priority score, for better
1629*993b0882SAndroid Build Coastguard Worker   // conflict resolution between entities.
1630*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() != nullptr &&
1631*993b0882SAndroid Build Coastguard Worker       model_->triggering_options()->collection_to_priority() != nullptr) {
1632*993b0882SAndroid Build Coastguard Worker     if (auto entry =
1633*993b0882SAndroid Build Coastguard Worker             model_->triggering_options()->collection_to_priority()->LookupByKey(
1634*993b0882SAndroid Build Coastguard Worker                 top_collection.c_str())) {
1635*993b0882SAndroid Build Coastguard Worker       (*classification_results)[0].priority_score *= entry->value();
1636*993b0882SAndroid Build Coastguard Worker     }
1637*993b0882SAndroid Build Coastguard Worker   }
1638*993b0882SAndroid Build Coastguard Worker   return true;
1639*993b0882SAndroid Build Coastguard Worker }
1640*993b0882SAndroid Build Coastguard Worker 
RegexClassifyText(const std::string & context,const CodepointSpan & selection_indices,std::vector<ClassificationResult> * classification_result) const1641*993b0882SAndroid Build Coastguard Worker bool Annotator::RegexClassifyText(
1642*993b0882SAndroid Build Coastguard Worker     const std::string& context, const CodepointSpan& selection_indices,
1643*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult>* classification_result) const {
1644*993b0882SAndroid Build Coastguard Worker   const std::string selection_text =
1645*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(context, /*do_copy=*/false)
1646*993b0882SAndroid Build Coastguard Worker           .UTF8Substring(selection_indices.first, selection_indices.second);
1647*993b0882SAndroid Build Coastguard Worker   const UnicodeText selection_text_unicode(
1648*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1649*993b0882SAndroid Build Coastguard Worker 
1650*993b0882SAndroid Build Coastguard Worker   // Check whether any of the regular expressions match.
1651*993b0882SAndroid Build Coastguard Worker   for (const int pattern_id : classification_regex_patterns_) {
1652*993b0882SAndroid Build Coastguard Worker     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1653*993b0882SAndroid Build Coastguard Worker     const std::unique_ptr<UniLib::RegexMatcher> matcher =
1654*993b0882SAndroid Build Coastguard Worker         regex_pattern.pattern->Matcher(selection_text_unicode);
1655*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
1656*993b0882SAndroid Build Coastguard Worker     bool matches;
1657*993b0882SAndroid Build Coastguard Worker     if (regex_pattern.config->use_approximate_matching()) {
1658*993b0882SAndroid Build Coastguard Worker       matches = matcher->ApproximatelyMatches(&status);
1659*993b0882SAndroid Build Coastguard Worker     } else {
1660*993b0882SAndroid Build Coastguard Worker       matches = matcher->Matches(&status);
1661*993b0882SAndroid Build Coastguard Worker     }
1662*993b0882SAndroid Build Coastguard Worker     if (status != UniLib::RegexMatcher::kNoError) {
1663*993b0882SAndroid Build Coastguard Worker       return false;
1664*993b0882SAndroid Build Coastguard Worker     }
1665*993b0882SAndroid Build Coastguard Worker     if (matches && VerifyRegexMatchCandidate(
1666*993b0882SAndroid Build Coastguard Worker                        context, regex_pattern.config->verification_options(),
1667*993b0882SAndroid Build Coastguard Worker                        selection_text, matcher.get())) {
1668*993b0882SAndroid Build Coastguard Worker       classification_result->push_back(
1669*993b0882SAndroid Build Coastguard Worker           {regex_pattern.config->collection_name()->str(),
1670*993b0882SAndroid Build Coastguard Worker            regex_pattern.config->target_classification_score(),
1671*993b0882SAndroid Build Coastguard Worker            regex_pattern.config->priority_score()});
1672*993b0882SAndroid Build Coastguard Worker       if (!SerializedEntityDataFromRegexMatch(
1673*993b0882SAndroid Build Coastguard Worker               regex_pattern.config, matcher.get(),
1674*993b0882SAndroid Build Coastguard Worker               &classification_result->back().serialized_entity_data)) {
1675*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not get entity data.";
1676*993b0882SAndroid Build Coastguard Worker         return false;
1677*993b0882SAndroid Build Coastguard Worker       }
1678*993b0882SAndroid Build Coastguard Worker     }
1679*993b0882SAndroid Build Coastguard Worker   }
1680*993b0882SAndroid Build Coastguard Worker 
1681*993b0882SAndroid Build Coastguard Worker   return true;
1682*993b0882SAndroid Build Coastguard Worker }
1683*993b0882SAndroid Build Coastguard Worker 
1684*993b0882SAndroid Build Coastguard Worker namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1685*993b0882SAndroid Build Coastguard Worker std::string PickCollectionForDatetime(
1686*993b0882SAndroid Build Coastguard Worker     const DatetimeParseResult& datetime_parse_result) {
1687*993b0882SAndroid Build Coastguard Worker   switch (datetime_parse_result.granularity) {
1688*993b0882SAndroid Build Coastguard Worker     case GRANULARITY_HOUR:
1689*993b0882SAndroid Build Coastguard Worker     case GRANULARITY_MINUTE:
1690*993b0882SAndroid Build Coastguard Worker     case GRANULARITY_SECOND:
1691*993b0882SAndroid Build Coastguard Worker       return Collections::DateTime();
1692*993b0882SAndroid Build Coastguard Worker     default:
1693*993b0882SAndroid Build Coastguard Worker       return Collections::Date();
1694*993b0882SAndroid Build Coastguard Worker   }
1695*993b0882SAndroid Build Coastguard Worker }
1696*993b0882SAndroid Build Coastguard Worker 
1697*993b0882SAndroid Build Coastguard Worker }  // namespace
1698*993b0882SAndroid Build Coastguard Worker 
DatetimeClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1699*993b0882SAndroid Build Coastguard Worker bool Annotator::DatetimeClassifyText(
1700*993b0882SAndroid Build Coastguard Worker     const std::string& context, const CodepointSpan& selection_indices,
1701*993b0882SAndroid Build Coastguard Worker     const ClassificationOptions& options,
1702*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult>* classification_results) const {
1703*993b0882SAndroid Build Coastguard Worker   if (!datetime_parser_) {
1704*993b0882SAndroid Build Coastguard Worker     return true;
1705*993b0882SAndroid Build Coastguard Worker   }
1706*993b0882SAndroid Build Coastguard Worker 
1707*993b0882SAndroid Build Coastguard Worker   const std::string selection_text =
1708*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(context, /*do_copy=*/false)
1709*993b0882SAndroid Build Coastguard Worker           .UTF8Substring(selection_indices.first, selection_indices.second);
1710*993b0882SAndroid Build Coastguard Worker 
1711*993b0882SAndroid Build Coastguard Worker   LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1712*993b0882SAndroid Build Coastguard Worker   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1713*993b0882SAndroid Build Coastguard Worker       datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1714*993b0882SAndroid Build Coastguard Worker                               options.reference_timezone, locale_list,
1715*993b0882SAndroid Build Coastguard Worker                               ModeFlag_CLASSIFICATION,
1716*993b0882SAndroid Build Coastguard Worker                               options.annotation_usecase,
1717*993b0882SAndroid Build Coastguard Worker                               /*anchor_start_end=*/true);
1718*993b0882SAndroid Build Coastguard Worker   if (!result_status.ok()) {
1719*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Error during parsing datetime.";
1720*993b0882SAndroid Build Coastguard Worker     return false;
1721*993b0882SAndroid Build Coastguard Worker   }
1722*993b0882SAndroid Build Coastguard Worker 
1723*993b0882SAndroid Build Coastguard Worker   for (const DatetimeParseResultSpan& datetime_span :
1724*993b0882SAndroid Build Coastguard Worker        result_status.ValueOrDie()) {
1725*993b0882SAndroid Build Coastguard Worker     // Only consider the result valid if the selection and extracted datetime
1726*993b0882SAndroid Build Coastguard Worker     // spans exactly match.
1727*993b0882SAndroid Build Coastguard Worker     if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1728*993b0882SAndroid Build Coastguard Worker                       datetime_span.span.second + selection_indices.first) ==
1729*993b0882SAndroid Build Coastguard Worker         selection_indices) {
1730*993b0882SAndroid Build Coastguard Worker       for (const DatetimeParseResult& parse_result : datetime_span.data) {
1731*993b0882SAndroid Build Coastguard Worker         classification_results->emplace_back(
1732*993b0882SAndroid Build Coastguard Worker             PickCollectionForDatetime(parse_result),
1733*993b0882SAndroid Build Coastguard Worker             datetime_span.target_classification_score);
1734*993b0882SAndroid Build Coastguard Worker         classification_results->back().datetime_parse_result = parse_result;
1735*993b0882SAndroid Build Coastguard Worker         classification_results->back().serialized_entity_data =
1736*993b0882SAndroid Build Coastguard Worker             CreateDatetimeSerializedEntityData(parse_result);
1737*993b0882SAndroid Build Coastguard Worker         classification_results->back().priority_score =
1738*993b0882SAndroid Build Coastguard Worker             datetime_span.priority_score;
1739*993b0882SAndroid Build Coastguard Worker       }
1740*993b0882SAndroid Build Coastguard Worker       return true;
1741*993b0882SAndroid Build Coastguard Worker     }
1742*993b0882SAndroid Build Coastguard Worker   }
1743*993b0882SAndroid Build Coastguard Worker   return true;
1744*993b0882SAndroid Build Coastguard Worker }
1745*993b0882SAndroid Build Coastguard Worker 
ClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options) const1746*993b0882SAndroid Build Coastguard Worker std::vector<ClassificationResult> Annotator::ClassifyText(
1747*993b0882SAndroid Build Coastguard Worker     const std::string& context, const CodepointSpan& selection_indices,
1748*993b0882SAndroid Build Coastguard Worker     const ClassificationOptions& options) const {
1749*993b0882SAndroid Build Coastguard Worker   if (context.size() > std::numeric_limits<int>::max()) {
1750*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1751*993b0882SAndroid Build Coastguard Worker     return {};
1752*993b0882SAndroid Build Coastguard Worker   }
1753*993b0882SAndroid Build Coastguard Worker   if (!initialized_) {
1754*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Not initialized";
1755*993b0882SAndroid Build Coastguard Worker     return {};
1756*993b0882SAndroid Build Coastguard Worker   }
1757*993b0882SAndroid Build Coastguard Worker   if (options.annotation_usecase !=
1758*993b0882SAndroid Build Coastguard Worker       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1759*993b0882SAndroid Build Coastguard Worker     TC3_LOG(WARNING)
1760*993b0882SAndroid Build Coastguard Worker         << "Invoking ClassifyText, which is not supported in RAW mode.";
1761*993b0882SAndroid Build Coastguard Worker     return {};
1762*993b0882SAndroid Build Coastguard Worker   }
1763*993b0882SAndroid Build Coastguard Worker   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1764*993b0882SAndroid Build Coastguard Worker     return {};
1765*993b0882SAndroid Build Coastguard Worker   }
1766*993b0882SAndroid Build Coastguard Worker 
1767*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> detected_text_language_tags;
1768*993b0882SAndroid Build Coastguard Worker   if (!ParseLocales(options.detected_text_language_tags,
1769*993b0882SAndroid Build Coastguard Worker                     &detected_text_language_tags)) {
1770*993b0882SAndroid Build Coastguard Worker     TC3_LOG(WARNING)
1771*993b0882SAndroid Build Coastguard Worker         << "Failed to parse the detected_text_language_tags in options: "
1772*993b0882SAndroid Build Coastguard Worker         << options.detected_text_language_tags;
1773*993b0882SAndroid Build Coastguard Worker   }
1774*993b0882SAndroid Build Coastguard Worker   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1775*993b0882SAndroid Build Coastguard Worker                                     model_triggering_locales_,
1776*993b0882SAndroid Build Coastguard Worker                                     /*default_value=*/true)) {
1777*993b0882SAndroid Build Coastguard Worker     return {};
1778*993b0882SAndroid Build Coastguard Worker   }
1779*993b0882SAndroid Build Coastguard Worker 
1780*993b0882SAndroid Build Coastguard Worker   const UnicodeText context_unicode =
1781*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(context, /*do_copy=*/false);
1782*993b0882SAndroid Build Coastguard Worker 
1783*993b0882SAndroid Build Coastguard Worker   if (!unilib_->IsValidUtf8(context_unicode)) {
1784*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1785*993b0882SAndroid Build Coastguard Worker     return {};
1786*993b0882SAndroid Build Coastguard Worker   }
1787*993b0882SAndroid Build Coastguard Worker 
1788*993b0882SAndroid Build Coastguard Worker   if (!IsValidSpanInput(context_unicode, selection_indices)) {
1789*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1790*993b0882SAndroid Build Coastguard Worker                 << selection_indices.first << " " << selection_indices.second;
1791*993b0882SAndroid Build Coastguard Worker     return {};
1792*993b0882SAndroid Build Coastguard Worker   }
1793*993b0882SAndroid Build Coastguard Worker 
1794*993b0882SAndroid Build Coastguard Worker   // We'll accumulate a list of candidates, and pick the best candidate in the
1795*993b0882SAndroid Build Coastguard Worker   // end.
1796*993b0882SAndroid Build Coastguard Worker   std::vector<AnnotatedSpan> candidates;
1797*993b0882SAndroid Build Coastguard Worker 
1798*993b0882SAndroid Build Coastguard Worker   // Try the knowledge engine.
1799*993b0882SAndroid Build Coastguard Worker   // TODO(b/126579108): Propagate error status.
1800*993b0882SAndroid Build Coastguard Worker   ClassificationResult knowledge_result;
1801*993b0882SAndroid Build Coastguard Worker   if (knowledge_engine_ &&
1802*993b0882SAndroid Build Coastguard Worker       knowledge_engine_
1803*993b0882SAndroid Build Coastguard Worker           ->ClassifyText(context, selection_indices, options.annotation_usecase,
1804*993b0882SAndroid Build Coastguard Worker                          options.location_context, Permissions(),
1805*993b0882SAndroid Build Coastguard Worker                          &knowledge_result)
1806*993b0882SAndroid Build Coastguard Worker           .ok()) {
1807*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {knowledge_result}});
1808*993b0882SAndroid Build Coastguard Worker     candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1809*993b0882SAndroid Build Coastguard Worker   }
1810*993b0882SAndroid Build Coastguard Worker 
1811*993b0882SAndroid Build Coastguard Worker   AddContactMetadataToKnowledgeClassificationResults(&candidates);
1812*993b0882SAndroid Build Coastguard Worker 
1813*993b0882SAndroid Build Coastguard Worker   // Try the contact engine.
1814*993b0882SAndroid Build Coastguard Worker   // TODO(b/126579108): Propagate error status.
1815*993b0882SAndroid Build Coastguard Worker   ClassificationResult contact_result;
1816*993b0882SAndroid Build Coastguard Worker   if (contact_engine_ && contact_engine_->ClassifyText(
1817*993b0882SAndroid Build Coastguard Worker                              context, selection_indices, &contact_result)) {
1818*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {contact_result}});
1819*993b0882SAndroid Build Coastguard Worker   }
1820*993b0882SAndroid Build Coastguard Worker 
1821*993b0882SAndroid Build Coastguard Worker   // Try the person name engine.
1822*993b0882SAndroid Build Coastguard Worker   ClassificationResult person_name_result;
1823*993b0882SAndroid Build Coastguard Worker   if (person_name_engine_ &&
1824*993b0882SAndroid Build Coastguard Worker       person_name_engine_->ClassifyText(context, selection_indices,
1825*993b0882SAndroid Build Coastguard Worker                                         &person_name_result)) {
1826*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {person_name_result}});
1827*993b0882SAndroid Build Coastguard Worker     candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1828*993b0882SAndroid Build Coastguard Worker   }
1829*993b0882SAndroid Build Coastguard Worker 
1830*993b0882SAndroid Build Coastguard Worker   // Try the installed app engine.
1831*993b0882SAndroid Build Coastguard Worker   // TODO(b/126579108): Propagate error status.
1832*993b0882SAndroid Build Coastguard Worker   ClassificationResult installed_app_result;
1833*993b0882SAndroid Build Coastguard Worker   if (installed_app_engine_ &&
1834*993b0882SAndroid Build Coastguard Worker       installed_app_engine_->ClassifyText(context, selection_indices,
1835*993b0882SAndroid Build Coastguard Worker                                           &installed_app_result)) {
1836*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {installed_app_result}});
1837*993b0882SAndroid Build Coastguard Worker   }
1838*993b0882SAndroid Build Coastguard Worker 
1839*993b0882SAndroid Build Coastguard Worker   // Try the regular expression models.
1840*993b0882SAndroid Build Coastguard Worker   std::vector<ClassificationResult> regex_results;
1841*993b0882SAndroid Build Coastguard Worker   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1842*993b0882SAndroid Build Coastguard Worker     return {};
1843*993b0882SAndroid Build Coastguard Worker   }
1844*993b0882SAndroid Build Coastguard Worker   for (const ClassificationResult& result : regex_results) {
1845*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {result}});
1846*993b0882SAndroid Build Coastguard Worker   }
1847*993b0882SAndroid Build Coastguard Worker 
1848*993b0882SAndroid Build Coastguard Worker   // Try the date model.
1849*993b0882SAndroid Build Coastguard Worker   //
1850*993b0882SAndroid Build Coastguard Worker   // DatetimeClassifyText only returns the first result, which can however have
1851*993b0882SAndroid Build Coastguard Worker   // more interpretations. They are inserted in the candidates as a single
1852*993b0882SAndroid Build Coastguard Worker   // AnnotatedSpan, so that they get treated together by the conflict resolution
1853*993b0882SAndroid Build Coastguard Worker   // algorithm.
1854*993b0882SAndroid Build Coastguard Worker   std::vector<ClassificationResult> datetime_results;
1855*993b0882SAndroid Build Coastguard Worker   if (!DatetimeClassifyText(context, selection_indices, options,
1856*993b0882SAndroid Build Coastguard Worker                             &datetime_results)) {
1857*993b0882SAndroid Build Coastguard Worker     return {};
1858*993b0882SAndroid Build Coastguard Worker   }
1859*993b0882SAndroid Build Coastguard Worker   if (!datetime_results.empty()) {
1860*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, std::move(datetime_results)});
1861*993b0882SAndroid Build Coastguard Worker     candidates.back().source = AnnotatedSpan::Source::DATETIME;
1862*993b0882SAndroid Build Coastguard Worker   }
1863*993b0882SAndroid Build Coastguard Worker 
1864*993b0882SAndroid Build Coastguard Worker   // Try the number annotator.
1865*993b0882SAndroid Build Coastguard Worker   // TODO(b/126579108): Propagate error status.
1866*993b0882SAndroid Build Coastguard Worker   ClassificationResult number_annotator_result;
1867*993b0882SAndroid Build Coastguard Worker   if (number_annotator_ &&
1868*993b0882SAndroid Build Coastguard Worker       number_annotator_->ClassifyText(context_unicode, selection_indices,
1869*993b0882SAndroid Build Coastguard Worker                                       options.annotation_usecase,
1870*993b0882SAndroid Build Coastguard Worker                                       &number_annotator_result)) {
1871*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {number_annotator_result}});
1872*993b0882SAndroid Build Coastguard Worker   }
1873*993b0882SAndroid Build Coastguard Worker 
1874*993b0882SAndroid Build Coastguard Worker   // Try the duration annotator.
1875*993b0882SAndroid Build Coastguard Worker   ClassificationResult duration_annotator_result;
1876*993b0882SAndroid Build Coastguard Worker   if (duration_annotator_ &&
1877*993b0882SAndroid Build Coastguard Worker       duration_annotator_->ClassifyText(context_unicode, selection_indices,
1878*993b0882SAndroid Build Coastguard Worker                                         options.annotation_usecase,
1879*993b0882SAndroid Build Coastguard Worker                                         &duration_annotator_result)) {
1880*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {duration_annotator_result}});
1881*993b0882SAndroid Build Coastguard Worker     candidates.back().source = AnnotatedSpan::Source::DURATION;
1882*993b0882SAndroid Build Coastguard Worker   }
1883*993b0882SAndroid Build Coastguard Worker 
1884*993b0882SAndroid Build Coastguard Worker   // Try the translate annotator.
1885*993b0882SAndroid Build Coastguard Worker   ClassificationResult translate_annotator_result;
1886*993b0882SAndroid Build Coastguard Worker   if (translate_annotator_ &&
1887*993b0882SAndroid Build Coastguard Worker       translate_annotator_->ClassifyText(context_unicode, selection_indices,
1888*993b0882SAndroid Build Coastguard Worker                                          options.user_familiar_language_tags,
1889*993b0882SAndroid Build Coastguard Worker                                          &translate_annotator_result)) {
1890*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {translate_annotator_result}});
1891*993b0882SAndroid Build Coastguard Worker   }
1892*993b0882SAndroid Build Coastguard Worker 
1893*993b0882SAndroid Build Coastguard Worker   // Try the grammar model.
1894*993b0882SAndroid Build Coastguard Worker   ClassificationResult grammar_annotator_result;
1895*993b0882SAndroid Build Coastguard Worker   if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1896*993b0882SAndroid Build Coastguard Worker                                 detected_text_language_tags, context_unicode,
1897*993b0882SAndroid Build Coastguard Worker                                 selection_indices, &grammar_annotator_result)) {
1898*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {grammar_annotator_result}});
1899*993b0882SAndroid Build Coastguard Worker   }
1900*993b0882SAndroid Build Coastguard Worker 
1901*993b0882SAndroid Build Coastguard Worker   ClassificationResult pod_ner_annotator_result;
1902*993b0882SAndroid Build Coastguard Worker   if (pod_ner_annotator_ && options.use_pod_ner &&
1903*993b0882SAndroid Build Coastguard Worker       pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1904*993b0882SAndroid Build Coastguard Worker                                        &pod_ner_annotator_result)) {
1905*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1906*993b0882SAndroid Build Coastguard Worker   }
1907*993b0882SAndroid Build Coastguard Worker 
1908*993b0882SAndroid Build Coastguard Worker   ClassificationResult vocab_annotator_result;
1909*993b0882SAndroid Build Coastguard Worker   if (vocab_annotator_ && options.use_vocab_annotator &&
1910*993b0882SAndroid Build Coastguard Worker       vocab_annotator_->ClassifyText(
1911*993b0882SAndroid Build Coastguard Worker           context_unicode, selection_indices, detected_text_language_tags,
1912*993b0882SAndroid Build Coastguard Worker           options.trigger_dictionary_on_beginner_words,
1913*993b0882SAndroid Build Coastguard Worker           &vocab_annotator_result)) {
1914*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, {vocab_annotator_result}});
1915*993b0882SAndroid Build Coastguard Worker   }
1916*993b0882SAndroid Build Coastguard Worker 
1917*993b0882SAndroid Build Coastguard Worker   if (experimental_annotator_ &&
1918*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options()->experimental_enabled_modes() &
1919*993b0882SAndroid Build Coastguard Worker        ModeFlag_CLASSIFICATION)) {
1920*993b0882SAndroid Build Coastguard Worker     experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1921*993b0882SAndroid Build Coastguard Worker                                           candidates);
1922*993b0882SAndroid Build Coastguard Worker   }
1923*993b0882SAndroid Build Coastguard Worker 
1924*993b0882SAndroid Build Coastguard Worker   // Try the ML model.
1925*993b0882SAndroid Build Coastguard Worker   //
1926*993b0882SAndroid Build Coastguard Worker   // The output of the model is considered as an exclusive 1-of-N choice. That's
1927*993b0882SAndroid Build Coastguard Worker   // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1928*993b0882SAndroid Build Coastguard Worker   // span for each candidate, like e.g. the regex model.
1929*993b0882SAndroid Build Coastguard Worker   InterpreterManager interpreter_manager(selection_executor_.get(),
1930*993b0882SAndroid Build Coastguard Worker                                          classification_executor_.get());
1931*993b0882SAndroid Build Coastguard Worker   std::vector<ClassificationResult> model_results;
1932*993b0882SAndroid Build Coastguard Worker   std::vector<Token> tokens;
1933*993b0882SAndroid Build Coastguard Worker   if (!ModelClassifyText(
1934*993b0882SAndroid Build Coastguard Worker           context, /*cached_tokens=*/{}, detected_text_language_tags,
1935*993b0882SAndroid Build Coastguard Worker           selection_indices, options, &interpreter_manager,
1936*993b0882SAndroid Build Coastguard Worker           /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1937*993b0882SAndroid Build Coastguard Worker     return {};
1938*993b0882SAndroid Build Coastguard Worker   }
1939*993b0882SAndroid Build Coastguard Worker   if (!model_results.empty()) {
1940*993b0882SAndroid Build Coastguard Worker     candidates.push_back({selection_indices, std::move(model_results)});
1941*993b0882SAndroid Build Coastguard Worker   }
1942*993b0882SAndroid Build Coastguard Worker 
1943*993b0882SAndroid Build Coastguard Worker   std::vector<int> candidate_indices;
1944*993b0882SAndroid Build Coastguard Worker   if (!ResolveConflicts(candidates, context, tokens,
1945*993b0882SAndroid Build Coastguard Worker                         detected_text_language_tags, options,
1946*993b0882SAndroid Build Coastguard Worker                         &interpreter_manager, &candidate_indices)) {
1947*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1948*993b0882SAndroid Build Coastguard Worker     return {};
1949*993b0882SAndroid Build Coastguard Worker   }
1950*993b0882SAndroid Build Coastguard Worker 
1951*993b0882SAndroid Build Coastguard Worker   std::vector<ClassificationResult> results;
1952*993b0882SAndroid Build Coastguard Worker   for (const int i : candidate_indices) {
1953*993b0882SAndroid Build Coastguard Worker     for (const ClassificationResult& result : candidates[i].classification) {
1954*993b0882SAndroid Build Coastguard Worker       if (!FilteredForClassification(result)) {
1955*993b0882SAndroid Build Coastguard Worker         results.push_back(result);
1956*993b0882SAndroid Build Coastguard Worker       }
1957*993b0882SAndroid Build Coastguard Worker     }
1958*993b0882SAndroid Build Coastguard Worker   }
1959*993b0882SAndroid Build Coastguard Worker 
1960*993b0882SAndroid Build Coastguard Worker   // Sort results according to score.
1961*993b0882SAndroid Build Coastguard Worker   std::stable_sort(
1962*993b0882SAndroid Build Coastguard Worker       results.begin(), results.end(),
1963*993b0882SAndroid Build Coastguard Worker       [](const ClassificationResult& a, const ClassificationResult& b) {
1964*993b0882SAndroid Build Coastguard Worker         return a.score > b.score;
1965*993b0882SAndroid Build Coastguard Worker       });
1966*993b0882SAndroid Build Coastguard Worker 
1967*993b0882SAndroid Build Coastguard Worker   if (results.empty()) {
1968*993b0882SAndroid Build Coastguard Worker     results = {{Collections::Other(), 1.0}};
1969*993b0882SAndroid Build Coastguard Worker   }
1970*993b0882SAndroid Build Coastguard Worker   return results;
1971*993b0882SAndroid Build Coastguard Worker }
1972*993b0882SAndroid Build Coastguard Worker 
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,const AnnotationOptions & options,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1973*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelAnnotate(
1974*993b0882SAndroid Build Coastguard Worker     const std::string& context,
1975*993b0882SAndroid Build Coastguard Worker     const std::vector<Locale>& detected_text_language_tags,
1976*993b0882SAndroid Build Coastguard Worker     const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1977*993b0882SAndroid Build Coastguard Worker     std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1978*993b0882SAndroid Build Coastguard Worker   bool skip_model_annotatation = false;
1979*993b0882SAndroid Build Coastguard Worker   if (model_->triggering_options() == nullptr ||
1980*993b0882SAndroid Build Coastguard Worker       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1981*993b0882SAndroid Build Coastguard Worker     skip_model_annotatation = true;
1982*993b0882SAndroid Build Coastguard Worker   }
1983*993b0882SAndroid Build Coastguard Worker   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1984*993b0882SAndroid Build Coastguard Worker                                     ml_model_triggering_locales_,
1985*993b0882SAndroid Build Coastguard Worker                                     /*default_value=*/true)) {
1986*993b0882SAndroid Build Coastguard Worker     skip_model_annotatation = true;
1987*993b0882SAndroid Build Coastguard Worker   }
1988*993b0882SAndroid Build Coastguard Worker 
1989*993b0882SAndroid Build Coastguard Worker   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1990*993b0882SAndroid Build Coastguard Worker                                                         /*do_copy=*/false);
1991*993b0882SAndroid Build Coastguard Worker   std::vector<UnicodeTextRange> lines;
1992*993b0882SAndroid Build Coastguard Worker   if (!selection_feature_processor_ ||
1993*993b0882SAndroid Build Coastguard Worker       !selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1994*993b0882SAndroid Build Coastguard Worker     lines.push_back({context_unicode.begin(), context_unicode.end()});
1995*993b0882SAndroid Build Coastguard Worker   } else {
1996*993b0882SAndroid Build Coastguard Worker     lines = selection_feature_processor_->SplitContext(
1997*993b0882SAndroid Build Coastguard Worker         context_unicode, selection_feature_processor_->GetOptions()
1998*993b0882SAndroid Build Coastguard Worker                              ->use_pipe_character_for_newline());
1999*993b0882SAndroid Build Coastguard Worker   }
2000*993b0882SAndroid Build Coastguard Worker 
2001*993b0882SAndroid Build Coastguard Worker   const float min_annotate_confidence =
2002*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options() != nullptr
2003*993b0882SAndroid Build Coastguard Worker            ? model_->triggering_options()->min_annotate_confidence()
2004*993b0882SAndroid Build Coastguard Worker            : 0.f);
2005*993b0882SAndroid Build Coastguard Worker 
2006*993b0882SAndroid Build Coastguard Worker   for (const UnicodeTextRange& line : lines) {
2007*993b0882SAndroid Build Coastguard Worker     const std::string line_str =
2008*993b0882SAndroid Build Coastguard Worker         UnicodeText::UTF8Substring(line.first, line.second);
2009*993b0882SAndroid Build Coastguard Worker 
2010*993b0882SAndroid Build Coastguard Worker     std::vector<Token> line_tokens;
2011*993b0882SAndroid Build Coastguard Worker     line_tokens = selection_feature_processor_->Tokenize(line_str);
2012*993b0882SAndroid Build Coastguard Worker 
2013*993b0882SAndroid Build Coastguard Worker     selection_feature_processor_->RetokenizeAndFindClick(
2014*993b0882SAndroid Build Coastguard Worker         line_str, {0, std::distance(line.first, line.second)},
2015*993b0882SAndroid Build Coastguard Worker         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
2016*993b0882SAndroid Build Coastguard Worker         &line_tokens,
2017*993b0882SAndroid Build Coastguard Worker         /*click_pos=*/nullptr);
2018*993b0882SAndroid Build Coastguard Worker     const TokenSpan full_line_span = {
2019*993b0882SAndroid Build Coastguard Worker         0, static_cast<TokenIndex>(line_tokens.size())};
2020*993b0882SAndroid Build Coastguard Worker 
2021*993b0882SAndroid Build Coastguard Worker     tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2022*993b0882SAndroid Build Coastguard Worker 
2023*993b0882SAndroid Build Coastguard Worker     if (skip_model_annotatation) {
2024*993b0882SAndroid Build Coastguard Worker       // We do not annotate, we only output the tokens.
2025*993b0882SAndroid Build Coastguard Worker       continue;
2026*993b0882SAndroid Build Coastguard Worker     }
2027*993b0882SAndroid Build Coastguard Worker 
2028*993b0882SAndroid Build Coastguard Worker     // TODO(zilka): Add support for greater granularity of this check.
2029*993b0882SAndroid Build Coastguard Worker     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
2030*993b0882SAndroid Build Coastguard Worker             line_tokens, full_line_span)) {
2031*993b0882SAndroid Build Coastguard Worker       continue;
2032*993b0882SAndroid Build Coastguard Worker     }
2033*993b0882SAndroid Build Coastguard Worker 
2034*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<CachedFeatures> cached_features;
2035*993b0882SAndroid Build Coastguard Worker     if (!selection_feature_processor_->ExtractFeatures(
2036*993b0882SAndroid Build Coastguard Worker             line_tokens, full_line_span,
2037*993b0882SAndroid Build Coastguard Worker             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2038*993b0882SAndroid Build Coastguard Worker             embedding_executor_.get(),
2039*993b0882SAndroid Build Coastguard Worker             /*embedding_cache=*/nullptr,
2040*993b0882SAndroid Build Coastguard Worker             selection_feature_processor_->EmbeddingSize() +
2041*993b0882SAndroid Build Coastguard Worker                 selection_feature_processor_->DenseFeaturesCount(),
2042*993b0882SAndroid Build Coastguard Worker             &cached_features)) {
2043*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not extract features.";
2044*993b0882SAndroid Build Coastguard Worker       return false;
2045*993b0882SAndroid Build Coastguard Worker     }
2046*993b0882SAndroid Build Coastguard Worker 
2047*993b0882SAndroid Build Coastguard Worker     std::vector<TokenSpan> local_chunks;
2048*993b0882SAndroid Build Coastguard Worker     if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2049*993b0882SAndroid Build Coastguard Worker                     interpreter_manager->SelectionInterpreter(),
2050*993b0882SAndroid Build Coastguard Worker                     *cached_features, &local_chunks)) {
2051*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not chunk.";
2052*993b0882SAndroid Build Coastguard Worker       return false;
2053*993b0882SAndroid Build Coastguard Worker     }
2054*993b0882SAndroid Build Coastguard Worker 
2055*993b0882SAndroid Build Coastguard Worker     const int offset = std::distance(context_unicode.begin(), line.first);
2056*993b0882SAndroid Build Coastguard Worker     if (local_chunks.empty()) {
2057*993b0882SAndroid Build Coastguard Worker       continue;
2058*993b0882SAndroid Build Coastguard Worker     }
2059*993b0882SAndroid Build Coastguard Worker     const UnicodeText line_unicode =
2060*993b0882SAndroid Build Coastguard Worker         UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2061*993b0882SAndroid Build Coastguard Worker     std::vector<UnicodeText::const_iterator> line_codepoints =
2062*993b0882SAndroid Build Coastguard Worker         line_unicode.Codepoints();
2063*993b0882SAndroid Build Coastguard Worker     line_codepoints.push_back(line_unicode.end());
2064*993b0882SAndroid Build Coastguard Worker 
2065*993b0882SAndroid Build Coastguard Worker     FeatureProcessor::EmbeddingCache embedding_cache;
2066*993b0882SAndroid Build Coastguard Worker     for (const TokenSpan& chunk : local_chunks) {
2067*993b0882SAndroid Build Coastguard Worker       CodepointSpan codepoint_span =
2068*993b0882SAndroid Build Coastguard Worker           TokenSpanToCodepointSpan(line_tokens, chunk);
2069*993b0882SAndroid Build Coastguard Worker       if (!codepoint_span.IsValid() ||
2070*993b0882SAndroid Build Coastguard Worker           codepoint_span.second > line_codepoints.size()) {
2071*993b0882SAndroid Build Coastguard Worker         continue;
2072*993b0882SAndroid Build Coastguard Worker       }
2073*993b0882SAndroid Build Coastguard Worker       codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2074*993b0882SAndroid Build Coastguard Worker           /*span_begin=*/line_codepoints[codepoint_span.first],
2075*993b0882SAndroid Build Coastguard Worker           /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
2076*993b0882SAndroid Build Coastguard Worker       if (model_->selection_options()->strip_unpaired_brackets()) {
2077*993b0882SAndroid Build Coastguard Worker         codepoint_span = StripUnpairedBrackets(
2078*993b0882SAndroid Build Coastguard Worker             /*span_begin=*/line_codepoints[codepoint_span.first],
2079*993b0882SAndroid Build Coastguard Worker             /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
2080*993b0882SAndroid Build Coastguard Worker             *unilib_);
2081*993b0882SAndroid Build Coastguard Worker       }
2082*993b0882SAndroid Build Coastguard Worker 
2083*993b0882SAndroid Build Coastguard Worker       // Skip empty spans.
2084*993b0882SAndroid Build Coastguard Worker       if (codepoint_span.first != codepoint_span.second) {
2085*993b0882SAndroid Build Coastguard Worker         std::vector<ClassificationResult> classification;
2086*993b0882SAndroid Build Coastguard Worker         if (!ModelClassifyText(
2087*993b0882SAndroid Build Coastguard Worker                 line_unicode, line_tokens, detected_text_language_tags,
2088*993b0882SAndroid Build Coastguard Worker                 /*span_begin=*/line_codepoints[codepoint_span.first],
2089*993b0882SAndroid Build Coastguard Worker                 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2090*993b0882SAndroid Build Coastguard Worker                 codepoint_span, options, interpreter_manager, &embedding_cache,
2091*993b0882SAndroid Build Coastguard Worker                 &classification, /*tokens=*/nullptr)) {
2092*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR) << "Could not classify text: "
2093*993b0882SAndroid Build Coastguard Worker                          << (codepoint_span.first + offset) << " "
2094*993b0882SAndroid Build Coastguard Worker                          << (codepoint_span.second + offset);
2095*993b0882SAndroid Build Coastguard Worker           return false;
2096*993b0882SAndroid Build Coastguard Worker         }
2097*993b0882SAndroid Build Coastguard Worker 
2098*993b0882SAndroid Build Coastguard Worker         // Do not include the span if it's classified as "other".
2099*993b0882SAndroid Build Coastguard Worker         if (!classification.empty() && !ClassifiedAsOther(classification) &&
2100*993b0882SAndroid Build Coastguard Worker             classification[0].score >= min_annotate_confidence) {
2101*993b0882SAndroid Build Coastguard Worker           AnnotatedSpan result_span;
2102*993b0882SAndroid Build Coastguard Worker           result_span.span = {codepoint_span.first + offset,
2103*993b0882SAndroid Build Coastguard Worker                               codepoint_span.second + offset};
2104*993b0882SAndroid Build Coastguard Worker           result_span.classification = std::move(classification);
2105*993b0882SAndroid Build Coastguard Worker           result->push_back(std::move(result_span));
2106*993b0882SAndroid Build Coastguard Worker         }
2107*993b0882SAndroid Build Coastguard Worker       }
2108*993b0882SAndroid Build Coastguard Worker     }
2109*993b0882SAndroid Build Coastguard Worker   }
2110*993b0882SAndroid Build Coastguard Worker   return true;
2111*993b0882SAndroid Build Coastguard Worker }
2112*993b0882SAndroid Build Coastguard Worker 
SelectionFeatureProcessorForTests() const2113*993b0882SAndroid Build Coastguard Worker const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2114*993b0882SAndroid Build Coastguard Worker   return selection_feature_processor_.get();
2115*993b0882SAndroid Build Coastguard Worker }
2116*993b0882SAndroid Build Coastguard Worker 
ClassificationFeatureProcessorForTests() const2117*993b0882SAndroid Build Coastguard Worker const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2118*993b0882SAndroid Build Coastguard Worker     const {
2119*993b0882SAndroid Build Coastguard Worker   return classification_feature_processor_.get();
2120*993b0882SAndroid Build Coastguard Worker }
2121*993b0882SAndroid Build Coastguard Worker 
DatetimeParserForTests() const2122*993b0882SAndroid Build Coastguard Worker const DatetimeParser* Annotator::DatetimeParserForTests() const {
2123*993b0882SAndroid Build Coastguard Worker   return datetime_parser_.get();
2124*993b0882SAndroid Build Coastguard Worker }
2125*993b0882SAndroid Build Coastguard Worker 
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const2126*993b0882SAndroid Build Coastguard Worker void Annotator::RemoveNotEnabledEntityTypes(
2127*993b0882SAndroid Build Coastguard Worker     const EnabledEntityTypes& is_entity_type_enabled,
2128*993b0882SAndroid Build Coastguard Worker     std::vector<AnnotatedSpan>* annotated_spans) const {
2129*993b0882SAndroid Build Coastguard Worker   for (AnnotatedSpan& annotated_span : *annotated_spans) {
2130*993b0882SAndroid Build Coastguard Worker     std::vector<ClassificationResult>& classifications =
2131*993b0882SAndroid Build Coastguard Worker         annotated_span.classification;
2132*993b0882SAndroid Build Coastguard Worker     classifications.erase(
2133*993b0882SAndroid Build Coastguard Worker         std::remove_if(classifications.begin(), classifications.end(),
2134*993b0882SAndroid Build Coastguard Worker                        [&is_entity_type_enabled](
2135*993b0882SAndroid Build Coastguard Worker                            const ClassificationResult& classification_result) {
2136*993b0882SAndroid Build Coastguard Worker                          return !is_entity_type_enabled(
2137*993b0882SAndroid Build Coastguard Worker                              classification_result.collection);
2138*993b0882SAndroid Build Coastguard Worker                        }),
2139*993b0882SAndroid Build Coastguard Worker         classifications.end());
2140*993b0882SAndroid Build Coastguard Worker   }
2141*993b0882SAndroid Build Coastguard Worker   annotated_spans->erase(
2142*993b0882SAndroid Build Coastguard Worker       std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2143*993b0882SAndroid Build Coastguard Worker                      [](const AnnotatedSpan& annotated_span) {
2144*993b0882SAndroid Build Coastguard Worker                        return annotated_span.classification.empty();
2145*993b0882SAndroid Build Coastguard Worker                      }),
2146*993b0882SAndroid Build Coastguard Worker       annotated_spans->end());
2147*993b0882SAndroid Build Coastguard Worker }
2148*993b0882SAndroid Build Coastguard Worker 
AddContactMetadataToKnowledgeClassificationResults(std::vector<AnnotatedSpan> * candidates) const2149*993b0882SAndroid Build Coastguard Worker void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2150*993b0882SAndroid Build Coastguard Worker     std::vector<AnnotatedSpan>* candidates) const {
2151*993b0882SAndroid Build Coastguard Worker   if (candidates == nullptr || contact_engine_ == nullptr) {
2152*993b0882SAndroid Build Coastguard Worker     return;
2153*993b0882SAndroid Build Coastguard Worker   }
2154*993b0882SAndroid Build Coastguard Worker   for (auto& candidate : *candidates) {
2155*993b0882SAndroid Build Coastguard Worker     for (auto& classification_result : candidate.classification) {
2156*993b0882SAndroid Build Coastguard Worker       contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2157*993b0882SAndroid Build Coastguard Worker           &classification_result);
2158*993b0882SAndroid Build Coastguard Worker     }
2159*993b0882SAndroid Build Coastguard Worker   }
2160*993b0882SAndroid Build Coastguard Worker }
2161*993b0882SAndroid Build Coastguard Worker 
AnnotateSingleInput(const std::string & context,const AnnotationOptions & options,std::vector<AnnotatedSpan> * candidates) const2162*993b0882SAndroid Build Coastguard Worker Status Annotator::AnnotateSingleInput(
2163*993b0882SAndroid Build Coastguard Worker     const std::string& context, const AnnotationOptions& options,
2164*993b0882SAndroid Build Coastguard Worker     std::vector<AnnotatedSpan>* candidates) const {
2165*993b0882SAndroid Build Coastguard Worker   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2166*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2167*993b0882SAndroid Build Coastguard Worker   }
2168*993b0882SAndroid Build Coastguard Worker 
2169*993b0882SAndroid Build Coastguard Worker   const UnicodeText context_unicode =
2170*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(context, /*do_copy=*/false);
2171*993b0882SAndroid Build Coastguard Worker 
2172*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> detected_text_language_tags;
2173*993b0882SAndroid Build Coastguard Worker   if (!ParseLocales(options.detected_text_language_tags,
2174*993b0882SAndroid Build Coastguard Worker                     &detected_text_language_tags)) {
2175*993b0882SAndroid Build Coastguard Worker     TC3_LOG(WARNING)
2176*993b0882SAndroid Build Coastguard Worker         << "Failed to parse the detected_text_language_tags in options: "
2177*993b0882SAndroid Build Coastguard Worker         << options.detected_text_language_tags;
2178*993b0882SAndroid Build Coastguard Worker   }
2179*993b0882SAndroid Build Coastguard Worker   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2180*993b0882SAndroid Build Coastguard Worker                                     model_triggering_locales_,
2181*993b0882SAndroid Build Coastguard Worker                                     /*default_value=*/true)) {
2182*993b0882SAndroid Build Coastguard Worker     return Status(
2183*993b0882SAndroid Build Coastguard Worker         StatusCode::UNAVAILABLE,
2184*993b0882SAndroid Build Coastguard Worker         "The detected language tags are not in the supported locales.");
2185*993b0882SAndroid Build Coastguard Worker   }
2186*993b0882SAndroid Build Coastguard Worker 
2187*993b0882SAndroid Build Coastguard Worker   InterpreterManager interpreter_manager(selection_executor_.get(),
2188*993b0882SAndroid Build Coastguard Worker                                          classification_executor_.get());
2189*993b0882SAndroid Build Coastguard Worker 
2190*993b0882SAndroid Build Coastguard Worker   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2191*993b0882SAndroid Build Coastguard Worker   const bool is_raw_usecase =
2192*993b0882SAndroid Build Coastguard Worker       options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2193*993b0882SAndroid Build Coastguard Worker 
2194*993b0882SAndroid Build Coastguard Worker   // Annotate with the selection model.
2195*993b0882SAndroid Build Coastguard Worker   const bool model_annotations_enabled =
2196*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2197*993b0882SAndroid Build Coastguard Worker   std::vector<Token> tokens;
2198*993b0882SAndroid Build Coastguard Worker   if (model_annotations_enabled &&
2199*993b0882SAndroid Build Coastguard Worker       !ModelAnnotate(context, detected_text_language_tags, options,
2200*993b0882SAndroid Build Coastguard Worker                      &interpreter_manager, &tokens, candidates)) {
2201*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2202*993b0882SAndroid Build Coastguard Worker   } else if (!model_annotations_enabled) {
2203*993b0882SAndroid Build Coastguard Worker     // If the ML model didn't run, we need to tokenize to support the other
2204*993b0882SAndroid Build Coastguard Worker     // annotators that depend on the tokens.
2205*993b0882SAndroid Build Coastguard Worker     // Optimization could be made to only do this when an annotator that uses
2206*993b0882SAndroid Build Coastguard Worker     // the tokens is enabled, but it's unclear if the added complexity is worth
2207*993b0882SAndroid Build Coastguard Worker     // it.
2208*993b0882SAndroid Build Coastguard Worker     if (selection_feature_processor_ != nullptr) {
2209*993b0882SAndroid Build Coastguard Worker       tokens = selection_feature_processor_->Tokenize(context_unicode);
2210*993b0882SAndroid Build Coastguard Worker     }
2211*993b0882SAndroid Build Coastguard Worker   }
2212*993b0882SAndroid Build Coastguard Worker 
2213*993b0882SAndroid Build Coastguard Worker   // Annotate with the regular expression models.
2214*993b0882SAndroid Build Coastguard Worker   const bool regex_annotations_enabled =
2215*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2216*993b0882SAndroid Build Coastguard Worker   if (regex_annotations_enabled &&
2217*993b0882SAndroid Build Coastguard Worker       !RegexChunk(
2218*993b0882SAndroid Build Coastguard Worker           UTF8ToUnicodeText(context, /*do_copy=*/false),
2219*993b0882SAndroid Build Coastguard Worker           annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2220*993b0882SAndroid Build Coastguard Worker           is_entity_type_enabled, options.annotation_usecase, candidates)) {
2221*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2222*993b0882SAndroid Build Coastguard Worker   }
2223*993b0882SAndroid Build Coastguard Worker 
2224*993b0882SAndroid Build Coastguard Worker   // Annotate with the datetime model.
2225*993b0882SAndroid Build Coastguard Worker   // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2226*993b0882SAndroid Build Coastguard Worker   // relatively slow for some clients.
2227*993b0882SAndroid Build Coastguard Worker   if ((is_entity_type_enabled(Collections::Date()) ||
2228*993b0882SAndroid Build Coastguard Worker        is_entity_type_enabled(Collections::DateTime())) &&
2229*993b0882SAndroid Build Coastguard Worker       !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2230*993b0882SAndroid Build Coastguard Worker                      options.reference_time_ms_utc, options.reference_timezone,
2231*993b0882SAndroid Build Coastguard Worker                      options.locales, ModeFlag_ANNOTATION,
2232*993b0882SAndroid Build Coastguard Worker                      options.annotation_usecase,
2233*993b0882SAndroid Build Coastguard Worker                      options.is_serialized_entity_data_enabled, candidates)) {
2234*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2235*993b0882SAndroid Build Coastguard Worker   }
2236*993b0882SAndroid Build Coastguard Worker 
2237*993b0882SAndroid Build Coastguard Worker   // Annotate with the contact engine.
2238*993b0882SAndroid Build Coastguard Worker   const bool contact_annotations_enabled =
2239*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2240*993b0882SAndroid Build Coastguard Worker   if (contact_annotations_enabled && contact_engine_ &&
2241*993b0882SAndroid Build Coastguard Worker       !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
2242*993b0882SAndroid Build Coastguard Worker                               candidates)) {
2243*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2244*993b0882SAndroid Build Coastguard Worker   }
2245*993b0882SAndroid Build Coastguard Worker 
2246*993b0882SAndroid Build Coastguard Worker   // Annotate with the installed app engine.
2247*993b0882SAndroid Build Coastguard Worker   const bool app_annotations_enabled =
2248*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || is_entity_type_enabled(Collections::App());
2249*993b0882SAndroid Build Coastguard Worker   if (app_annotations_enabled && installed_app_engine_ &&
2250*993b0882SAndroid Build Coastguard Worker       !installed_app_engine_->Chunk(context_unicode, tokens,
2251*993b0882SAndroid Build Coastguard Worker                                     ModeFlag_ANNOTATION, candidates)) {
2252*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL,
2253*993b0882SAndroid Build Coastguard Worker                   "Couldn't run installed app engine Chunk.");
2254*993b0882SAndroid Build Coastguard Worker   }
2255*993b0882SAndroid Build Coastguard Worker 
2256*993b0882SAndroid Build Coastguard Worker   // Annotate with the number annotator.
2257*993b0882SAndroid Build Coastguard Worker   const bool number_annotations_enabled =
2258*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2259*993b0882SAndroid Build Coastguard Worker                           is_entity_type_enabled(Collections::Percentage()));
2260*993b0882SAndroid Build Coastguard Worker   if (number_annotations_enabled && number_annotator_ != nullptr &&
2261*993b0882SAndroid Build Coastguard Worker       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2262*993b0882SAndroid Build Coastguard Worker                                   ModeFlag_ANNOTATION, candidates)) {
2263*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL,
2264*993b0882SAndroid Build Coastguard Worker                   "Couldn't run number annotator FindAll.");
2265*993b0882SAndroid Build Coastguard Worker   }
2266*993b0882SAndroid Build Coastguard Worker 
2267*993b0882SAndroid Build Coastguard Worker   // Annotate with the duration annotator.
2268*993b0882SAndroid Build Coastguard Worker   const bool duration_annotations_enabled =
2269*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2270*993b0882SAndroid Build Coastguard Worker   if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2271*993b0882SAndroid Build Coastguard Worker       !duration_annotator_->FindAll(context_unicode, tokens,
2272*993b0882SAndroid Build Coastguard Worker                                     options.annotation_usecase,
2273*993b0882SAndroid Build Coastguard Worker                                     ModeFlag_ANNOTATION, candidates)) {
2274*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL,
2275*993b0882SAndroid Build Coastguard Worker                   "Couldn't run duration annotator FindAll.");
2276*993b0882SAndroid Build Coastguard Worker   }
2277*993b0882SAndroid Build Coastguard Worker 
2278*993b0882SAndroid Build Coastguard Worker   // Annotate with the person name engine.
2279*993b0882SAndroid Build Coastguard Worker   const bool person_annotations_enabled =
2280*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2281*993b0882SAndroid Build Coastguard Worker   if (person_annotations_enabled && person_name_engine_ &&
2282*993b0882SAndroid Build Coastguard Worker       !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
2283*993b0882SAndroid Build Coastguard Worker                                   candidates)) {
2284*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL,
2285*993b0882SAndroid Build Coastguard Worker                   "Couldn't run person name engine Chunk.");
2286*993b0882SAndroid Build Coastguard Worker   }
2287*993b0882SAndroid Build Coastguard Worker 
2288*993b0882SAndroid Build Coastguard Worker   // Annotate with the grammar annotators.
2289*993b0882SAndroid Build Coastguard Worker   if (grammar_annotator_ != nullptr &&
2290*993b0882SAndroid Build Coastguard Worker       !grammar_annotator_->Annotate(detected_text_language_tags,
2291*993b0882SAndroid Build Coastguard Worker                                     context_unicode, candidates)) {
2292*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2293*993b0882SAndroid Build Coastguard Worker   }
2294*993b0882SAndroid Build Coastguard Worker 
2295*993b0882SAndroid Build Coastguard Worker   // Annotate with the POD NER annotator.
2296*993b0882SAndroid Build Coastguard Worker   const bool pod_ner_annotations_enabled =
2297*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2298*993b0882SAndroid Build Coastguard Worker   if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2299*993b0882SAndroid Build Coastguard Worker       options.use_pod_ner &&
2300*993b0882SAndroid Build Coastguard Worker       !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2301*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2302*993b0882SAndroid Build Coastguard Worker   }
2303*993b0882SAndroid Build Coastguard Worker 
2304*993b0882SAndroid Build Coastguard Worker   // Annotate with the vocab annotator.
2305*993b0882SAndroid Build Coastguard Worker   const bool vocab_annotations_enabled =
2306*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2307*993b0882SAndroid Build Coastguard Worker   if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2308*993b0882SAndroid Build Coastguard Worker       options.use_vocab_annotator &&
2309*993b0882SAndroid Build Coastguard Worker       !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2310*993b0882SAndroid Build Coastguard Worker                                   options.trigger_dictionary_on_beginner_words,
2311*993b0882SAndroid Build Coastguard Worker                                   candidates)) {
2312*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2313*993b0882SAndroid Build Coastguard Worker   }
2314*993b0882SAndroid Build Coastguard Worker 
2315*993b0882SAndroid Build Coastguard Worker   // Annotate with the experimental annotator.
2316*993b0882SAndroid Build Coastguard Worker   if (experimental_annotator_ != nullptr &&
2317*993b0882SAndroid Build Coastguard Worker       (model_->triggering_options()->experimental_enabled_modes() &
2318*993b0882SAndroid Build Coastguard Worker        ModeFlag_ANNOTATION) &&
2319*993b0882SAndroid Build Coastguard Worker       !experimental_annotator_->Annotate(context_unicode, candidates)) {
2320*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2321*993b0882SAndroid Build Coastguard Worker   }
2322*993b0882SAndroid Build Coastguard Worker 
2323*993b0882SAndroid Build Coastguard Worker   // Sort candidates according to their position in the input, so that the next
2324*993b0882SAndroid Build Coastguard Worker   // code can assume that any connected component of overlapping spans forms a
2325*993b0882SAndroid Build Coastguard Worker   // contiguous block.
2326*993b0882SAndroid Build Coastguard Worker   // Also sort them according to the end position and collection, so that the
2327*993b0882SAndroid Build Coastguard Worker   // deduplication code below can assume that same spans and classifications
2328*993b0882SAndroid Build Coastguard Worker   // form contiguous blocks.
2329*993b0882SAndroid Build Coastguard Worker   std::stable_sort(candidates->begin(), candidates->end(),
2330*993b0882SAndroid Build Coastguard Worker                    [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2331*993b0882SAndroid Build Coastguard Worker                      if (a.span.first != b.span.first) {
2332*993b0882SAndroid Build Coastguard Worker                        return a.span.first < b.span.first;
2333*993b0882SAndroid Build Coastguard Worker                      }
2334*993b0882SAndroid Build Coastguard Worker 
2335*993b0882SAndroid Build Coastguard Worker                      if (a.span.second != b.span.second) {
2336*993b0882SAndroid Build Coastguard Worker                        return a.span.second < b.span.second;
2337*993b0882SAndroid Build Coastguard Worker                      }
2338*993b0882SAndroid Build Coastguard Worker 
2339*993b0882SAndroid Build Coastguard Worker                      return a.classification[0].collection <
2340*993b0882SAndroid Build Coastguard Worker                             b.classification[0].collection;
2341*993b0882SAndroid Build Coastguard Worker                    });
2342*993b0882SAndroid Build Coastguard Worker 
2343*993b0882SAndroid Build Coastguard Worker   std::vector<int> candidate_indices;
2344*993b0882SAndroid Build Coastguard Worker   if (!ResolveConflicts(*candidates, context, tokens,
2345*993b0882SAndroid Build Coastguard Worker                         detected_text_language_tags, options,
2346*993b0882SAndroid Build Coastguard Worker                         &interpreter_manager, &candidate_indices)) {
2347*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2348*993b0882SAndroid Build Coastguard Worker   }
2349*993b0882SAndroid Build Coastguard Worker 
2350*993b0882SAndroid Build Coastguard Worker   // Remove candidates that overlap exactly and have the same collection.
2351*993b0882SAndroid Build Coastguard Worker   // This can e.g. happen for phone coming from both ML model and regex.
2352*993b0882SAndroid Build Coastguard Worker   candidate_indices.erase(
2353*993b0882SAndroid Build Coastguard Worker       std::unique(candidate_indices.begin(), candidate_indices.end(),
2354*993b0882SAndroid Build Coastguard Worker                   [&candidates](const int a_index, const int b_index) {
2355*993b0882SAndroid Build Coastguard Worker                     const AnnotatedSpan& a = (*candidates)[a_index];
2356*993b0882SAndroid Build Coastguard Worker                     const AnnotatedSpan& b = (*candidates)[b_index];
2357*993b0882SAndroid Build Coastguard Worker                     return a.span == b.span &&
2358*993b0882SAndroid Build Coastguard Worker                            a.classification[0].collection ==
2359*993b0882SAndroid Build Coastguard Worker                                b.classification[0].collection;
2360*993b0882SAndroid Build Coastguard Worker                   }),
2361*993b0882SAndroid Build Coastguard Worker       candidate_indices.end());
2362*993b0882SAndroid Build Coastguard Worker 
2363*993b0882SAndroid Build Coastguard Worker   std::vector<AnnotatedSpan> result;
2364*993b0882SAndroid Build Coastguard Worker   result.reserve(candidate_indices.size());
2365*993b0882SAndroid Build Coastguard Worker   for (const int i : candidate_indices) {
2366*993b0882SAndroid Build Coastguard Worker     if ((*candidates)[i].classification.empty() ||
2367*993b0882SAndroid Build Coastguard Worker         ClassifiedAsOther((*candidates)[i].classification) ||
2368*993b0882SAndroid Build Coastguard Worker         FilteredForAnnotation((*candidates)[i])) {
2369*993b0882SAndroid Build Coastguard Worker       continue;
2370*993b0882SAndroid Build Coastguard Worker     }
2371*993b0882SAndroid Build Coastguard Worker     result.push_back(std::move((*candidates)[i]));
2372*993b0882SAndroid Build Coastguard Worker   }
2373*993b0882SAndroid Build Coastguard Worker 
2374*993b0882SAndroid Build Coastguard Worker   // We generate all candidates and remove them later (with the exception of
2375*993b0882SAndroid Build Coastguard Worker   // date/time/duration entities) because there are complex interdependencies
2376*993b0882SAndroid Build Coastguard Worker   // between the entity types. E.g., the TLD of an email can be interpreted as a
2377*993b0882SAndroid Build Coastguard Worker   // URL, but most likely a user of the API does not want such annotations if
2378*993b0882SAndroid Build Coastguard Worker   // "url" is enabled and "email" is not.
2379*993b0882SAndroid Build Coastguard Worker   RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2380*993b0882SAndroid Build Coastguard Worker 
2381*993b0882SAndroid Build Coastguard Worker   for (AnnotatedSpan& annotated_span : result) {
2382*993b0882SAndroid Build Coastguard Worker     SortClassificationResults(&annotated_span.classification);
2383*993b0882SAndroid Build Coastguard Worker   }
2384*993b0882SAndroid Build Coastguard Worker   *candidates = result;
2385*993b0882SAndroid Build Coastguard Worker   return Status::OK;
2386*993b0882SAndroid Build Coastguard Worker }
2387*993b0882SAndroid Build Coastguard Worker 
AnnotateStructuredInput(const std::vector<InputFragment> & string_fragments,const AnnotationOptions & options) const2388*993b0882SAndroid Build Coastguard Worker StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2389*993b0882SAndroid Build Coastguard Worker     const std::vector<InputFragment>& string_fragments,
2390*993b0882SAndroid Build Coastguard Worker     const AnnotationOptions& options) const {
2391*993b0882SAndroid Build Coastguard Worker   Annotations annotation_candidates;
2392*993b0882SAndroid Build Coastguard Worker   annotation_candidates.annotated_spans.resize(string_fragments.size());
2393*993b0882SAndroid Build Coastguard Worker 
2394*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> text_to_annotate;
2395*993b0882SAndroid Build Coastguard Worker   text_to_annotate.reserve(string_fragments.size());
2396*993b0882SAndroid Build Coastguard Worker   std::vector<FragmentMetadata> fragment_metadata;
2397*993b0882SAndroid Build Coastguard Worker   fragment_metadata.reserve(string_fragments.size());
2398*993b0882SAndroid Build Coastguard Worker   for (const auto& string_fragment : string_fragments) {
2399*993b0882SAndroid Build Coastguard Worker     text_to_annotate.push_back(string_fragment.text);
2400*993b0882SAndroid Build Coastguard Worker     fragment_metadata.push_back(
2401*993b0882SAndroid Build Coastguard Worker         {.relative_bounding_box_top = string_fragment.bounding_box_top,
2402*993b0882SAndroid Build Coastguard Worker          .relative_bounding_box_height = string_fragment.bounding_box_height});
2403*993b0882SAndroid Build Coastguard Worker   }
2404*993b0882SAndroid Build Coastguard Worker 
2405*993b0882SAndroid Build Coastguard Worker   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2406*993b0882SAndroid Build Coastguard Worker   const bool is_raw_usecase =
2407*993b0882SAndroid Build Coastguard Worker       options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2408*993b0882SAndroid Build Coastguard Worker 
2409*993b0882SAndroid Build Coastguard Worker   const bool knowledge_engine_annotations_enabled =
2410*993b0882SAndroid Build Coastguard Worker       !is_raw_usecase || is_entity_type_enabled(Collections::Entity());
2411*993b0882SAndroid Build Coastguard Worker   // KnowledgeEngine is special, because it supports annotation of multiple
2412*993b0882SAndroid Build Coastguard Worker   // fragments at once.
2413*993b0882SAndroid Build Coastguard Worker   if (knowledge_engine_annotations_enabled && knowledge_engine_ &&
2414*993b0882SAndroid Build Coastguard Worker       !knowledge_engine_
2415*993b0882SAndroid Build Coastguard Worker            ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2416*993b0882SAndroid Build Coastguard Worker                                 options.annotation_usecase,
2417*993b0882SAndroid Build Coastguard Worker                                 options.location_context, options.permissions,
2418*993b0882SAndroid Build Coastguard Worker                                 options.annotate_mode, ModeFlag_ANNOTATION,
2419*993b0882SAndroid Build Coastguard Worker                                 &annotation_candidates)
2420*993b0882SAndroid Build Coastguard Worker            .ok()) {
2421*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2422*993b0882SAndroid Build Coastguard Worker   }
2423*993b0882SAndroid Build Coastguard Worker   // The annotator engines shouldn't change the number of annotation vectors.
2424*993b0882SAndroid Build Coastguard Worker   if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2425*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2426*993b0882SAndroid Build Coastguard Worker                    << " texts to annotate but generated a different number of  "
2427*993b0882SAndroid Build Coastguard Worker                       "lists of annotations:"
2428*993b0882SAndroid Build Coastguard Worker                    << annotation_candidates.annotated_spans.size();
2429*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INTERNAL,
2430*993b0882SAndroid Build Coastguard Worker                   "Number of annotation candidates differs from "
2431*993b0882SAndroid Build Coastguard Worker                   "number of texts to annotate.");
2432*993b0882SAndroid Build Coastguard Worker   }
2433*993b0882SAndroid Build Coastguard Worker 
2434*993b0882SAndroid Build Coastguard Worker   // As an optimization, if the only annotated type is Entity, we skip all the
2435*993b0882SAndroid Build Coastguard Worker   // other annotators than the KnowledgeEngine. This only happens in the raw
2436*993b0882SAndroid Build Coastguard Worker   // mode, to make sure it does not affect the result.
2437*993b0882SAndroid Build Coastguard Worker   if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2438*993b0882SAndroid Build Coastguard Worker       options.entity_types.size() == 1 &&
2439*993b0882SAndroid Build Coastguard Worker       *options.entity_types.begin() == Collections::Entity()) {
2440*993b0882SAndroid Build Coastguard Worker     return annotation_candidates;
2441*993b0882SAndroid Build Coastguard Worker   }
2442*993b0882SAndroid Build Coastguard Worker 
2443*993b0882SAndroid Build Coastguard Worker   // Other annotators run on each fragment independently.
2444*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < text_to_annotate.size(); ++i) {
2445*993b0882SAndroid Build Coastguard Worker     AnnotationOptions annotation_options = options;
2446*993b0882SAndroid Build Coastguard Worker     if (string_fragments[i].datetime_options.has_value()) {
2447*993b0882SAndroid Build Coastguard Worker       DatetimeOptions reference_datetime =
2448*993b0882SAndroid Build Coastguard Worker           string_fragments[i].datetime_options.value();
2449*993b0882SAndroid Build Coastguard Worker       annotation_options.reference_time_ms_utc =
2450*993b0882SAndroid Build Coastguard Worker           reference_datetime.reference_time_ms_utc;
2451*993b0882SAndroid Build Coastguard Worker       annotation_options.reference_timezone =
2452*993b0882SAndroid Build Coastguard Worker           reference_datetime.reference_timezone;
2453*993b0882SAndroid Build Coastguard Worker     }
2454*993b0882SAndroid Build Coastguard Worker 
2455*993b0882SAndroid Build Coastguard Worker     AddContactMetadataToKnowledgeClassificationResults(
2456*993b0882SAndroid Build Coastguard Worker         &annotation_candidates.annotated_spans[i]);
2457*993b0882SAndroid Build Coastguard Worker 
2458*993b0882SAndroid Build Coastguard Worker     Status annotation_status =
2459*993b0882SAndroid Build Coastguard Worker         AnnotateSingleInput(text_to_annotate[i], annotation_options,
2460*993b0882SAndroid Build Coastguard Worker                             &annotation_candidates.annotated_spans[i]);
2461*993b0882SAndroid Build Coastguard Worker     if (!annotation_status.ok()) {
2462*993b0882SAndroid Build Coastguard Worker       return annotation_status;
2463*993b0882SAndroid Build Coastguard Worker     }
2464*993b0882SAndroid Build Coastguard Worker   }
2465*993b0882SAndroid Build Coastguard Worker   return annotation_candidates;
2466*993b0882SAndroid Build Coastguard Worker }
2467*993b0882SAndroid Build Coastguard Worker 
Annotate(const std::string & context,const AnnotationOptions & options) const2468*993b0882SAndroid Build Coastguard Worker std::vector<AnnotatedSpan> Annotator::Annotate(
2469*993b0882SAndroid Build Coastguard Worker     const std::string& context, const AnnotationOptions& options) const {
2470*993b0882SAndroid Build Coastguard Worker   if (context.size() > std::numeric_limits<int>::max()) {
2471*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Rejecting too long input.";
2472*993b0882SAndroid Build Coastguard Worker     return {};
2473*993b0882SAndroid Build Coastguard Worker   }
2474*993b0882SAndroid Build Coastguard Worker 
2475*993b0882SAndroid Build Coastguard Worker   const UnicodeText context_unicode =
2476*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(context, /*do_copy=*/false);
2477*993b0882SAndroid Build Coastguard Worker   if (!unilib_->IsValidUtf8(context_unicode)) {
2478*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2479*993b0882SAndroid Build Coastguard Worker     return {};
2480*993b0882SAndroid Build Coastguard Worker   }
2481*993b0882SAndroid Build Coastguard Worker 
2482*993b0882SAndroid Build Coastguard Worker   std::vector<InputFragment> string_fragments;
2483*993b0882SAndroid Build Coastguard Worker   string_fragments.push_back({.text = context});
2484*993b0882SAndroid Build Coastguard Worker   StatusOr<Annotations> annotations =
2485*993b0882SAndroid Build Coastguard Worker       AnnotateStructuredInput(string_fragments, options);
2486*993b0882SAndroid Build Coastguard Worker   if (!annotations.ok()) {
2487*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2488*993b0882SAndroid Build Coastguard Worker                    << annotations.status().error_message();
2489*993b0882SAndroid Build Coastguard Worker     return {};
2490*993b0882SAndroid Build Coastguard Worker   }
2491*993b0882SAndroid Build Coastguard Worker   return annotations.ValueOrDie().annotated_spans[0];
2492*993b0882SAndroid Build Coastguard Worker }
2493*993b0882SAndroid Build Coastguard Worker 
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const2494*993b0882SAndroid Build Coastguard Worker CodepointSpan Annotator::ComputeSelectionBoundaries(
2495*993b0882SAndroid Build Coastguard Worker     const UniLib::RegexMatcher* match,
2496*993b0882SAndroid Build Coastguard Worker     const RegexModel_::Pattern* config) const {
2497*993b0882SAndroid Build Coastguard Worker   if (config->capturing_group() == nullptr) {
2498*993b0882SAndroid Build Coastguard Worker     // Use first capturing group to specify the selection.
2499*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
2500*993b0882SAndroid Build Coastguard Worker     const CodepointSpan result = {match->Start(1, &status),
2501*993b0882SAndroid Build Coastguard Worker                                   match->End(1, &status)};
2502*993b0882SAndroid Build Coastguard Worker     if (status != UniLib::RegexMatcher::kNoError) {
2503*993b0882SAndroid Build Coastguard Worker       return {kInvalidIndex, kInvalidIndex};
2504*993b0882SAndroid Build Coastguard Worker     }
2505*993b0882SAndroid Build Coastguard Worker     return result;
2506*993b0882SAndroid Build Coastguard Worker   }
2507*993b0882SAndroid Build Coastguard Worker 
2508*993b0882SAndroid Build Coastguard Worker   CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2509*993b0882SAndroid Build Coastguard Worker   const int num_groups = config->capturing_group()->size();
2510*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_groups; i++) {
2511*993b0882SAndroid Build Coastguard Worker     if (!config->capturing_group()->Get(i)->extend_selection()) {
2512*993b0882SAndroid Build Coastguard Worker       continue;
2513*993b0882SAndroid Build Coastguard Worker     }
2514*993b0882SAndroid Build Coastguard Worker 
2515*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
2516*993b0882SAndroid Build Coastguard Worker     // Check match and adjust bounds.
2517*993b0882SAndroid Build Coastguard Worker     const int group_start = match->Start(i, &status);
2518*993b0882SAndroid Build Coastguard Worker     const int group_end = match->End(i, &status);
2519*993b0882SAndroid Build Coastguard Worker     if (status != UniLib::RegexMatcher::kNoError) {
2520*993b0882SAndroid Build Coastguard Worker       return {kInvalidIndex, kInvalidIndex};
2521*993b0882SAndroid Build Coastguard Worker     }
2522*993b0882SAndroid Build Coastguard Worker     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2523*993b0882SAndroid Build Coastguard Worker       continue;
2524*993b0882SAndroid Build Coastguard Worker     }
2525*993b0882SAndroid Build Coastguard Worker     if (result.first == kInvalidIndex) {
2526*993b0882SAndroid Build Coastguard Worker       result = {group_start, group_end};
2527*993b0882SAndroid Build Coastguard Worker     } else {
2528*993b0882SAndroid Build Coastguard Worker       result.first = std::min(result.first, group_start);
2529*993b0882SAndroid Build Coastguard Worker       result.second = std::max(result.second, group_end);
2530*993b0882SAndroid Build Coastguard Worker     }
2531*993b0882SAndroid Build Coastguard Worker   }
2532*993b0882SAndroid Build Coastguard Worker   return result;
2533*993b0882SAndroid Build Coastguard Worker }
2534*993b0882SAndroid Build Coastguard Worker 
HasEntityData(const RegexModel_::Pattern * pattern) const2535*993b0882SAndroid Build Coastguard Worker bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2536*993b0882SAndroid Build Coastguard Worker   if (pattern->serialized_entity_data() != nullptr ||
2537*993b0882SAndroid Build Coastguard Worker       pattern->entity_data() != nullptr) {
2538*993b0882SAndroid Build Coastguard Worker     return true;
2539*993b0882SAndroid Build Coastguard Worker   }
2540*993b0882SAndroid Build Coastguard Worker   if (pattern->capturing_group() != nullptr) {
2541*993b0882SAndroid Build Coastguard Worker     for (const CapturingGroup* group : *pattern->capturing_group()) {
2542*993b0882SAndroid Build Coastguard Worker       if (group->entity_field_path() != nullptr) {
2543*993b0882SAndroid Build Coastguard Worker         return true;
2544*993b0882SAndroid Build Coastguard Worker       }
2545*993b0882SAndroid Build Coastguard Worker       if (group->serialized_entity_data() != nullptr ||
2546*993b0882SAndroid Build Coastguard Worker           group->entity_data() != nullptr) {
2547*993b0882SAndroid Build Coastguard Worker         return true;
2548*993b0882SAndroid Build Coastguard Worker       }
2549*993b0882SAndroid Build Coastguard Worker     }
2550*993b0882SAndroid Build Coastguard Worker   }
2551*993b0882SAndroid Build Coastguard Worker   return false;
2552*993b0882SAndroid Build Coastguard Worker }
2553*993b0882SAndroid Build Coastguard Worker 
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const2554*993b0882SAndroid Build Coastguard Worker bool Annotator::SerializedEntityDataFromRegexMatch(
2555*993b0882SAndroid Build Coastguard Worker     const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2556*993b0882SAndroid Build Coastguard Worker     std::string* serialized_entity_data) const {
2557*993b0882SAndroid Build Coastguard Worker   if (!HasEntityData(pattern)) {
2558*993b0882SAndroid Build Coastguard Worker     serialized_entity_data->clear();
2559*993b0882SAndroid Build Coastguard Worker     return true;
2560*993b0882SAndroid Build Coastguard Worker   }
2561*993b0882SAndroid Build Coastguard Worker   TC3_CHECK(entity_data_builder_ != nullptr);
2562*993b0882SAndroid Build Coastguard Worker 
2563*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<MutableFlatbuffer> entity_data =
2564*993b0882SAndroid Build Coastguard Worker       entity_data_builder_->NewRoot();
2565*993b0882SAndroid Build Coastguard Worker 
2566*993b0882SAndroid Build Coastguard Worker   TC3_CHECK(entity_data != nullptr);
2567*993b0882SAndroid Build Coastguard Worker 
2568*993b0882SAndroid Build Coastguard Worker   // Set fixed entity data.
2569*993b0882SAndroid Build Coastguard Worker   if (pattern->serialized_entity_data() != nullptr) {
2570*993b0882SAndroid Build Coastguard Worker     entity_data->MergeFromSerializedFlatbuffer(
2571*993b0882SAndroid Build Coastguard Worker         StringPiece(pattern->serialized_entity_data()->c_str(),
2572*993b0882SAndroid Build Coastguard Worker                     pattern->serialized_entity_data()->size()));
2573*993b0882SAndroid Build Coastguard Worker   }
2574*993b0882SAndroid Build Coastguard Worker   if (pattern->entity_data() != nullptr) {
2575*993b0882SAndroid Build Coastguard Worker     entity_data->MergeFrom(
2576*993b0882SAndroid Build Coastguard Worker         reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2577*993b0882SAndroid Build Coastguard Worker   }
2578*993b0882SAndroid Build Coastguard Worker 
2579*993b0882SAndroid Build Coastguard Worker   // Add entity data from rule capturing groups.
2580*993b0882SAndroid Build Coastguard Worker   if (pattern->capturing_group() != nullptr) {
2581*993b0882SAndroid Build Coastguard Worker     const int num_groups = pattern->capturing_group()->size();
2582*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < num_groups; i++) {
2583*993b0882SAndroid Build Coastguard Worker       const CapturingGroup* group = pattern->capturing_group()->Get(i);
2584*993b0882SAndroid Build Coastguard Worker 
2585*993b0882SAndroid Build Coastguard Worker       // Check whether the group matched.
2586*993b0882SAndroid Build Coastguard Worker       Optional<std::string> group_match_text =
2587*993b0882SAndroid Build Coastguard Worker           GetCapturingGroupText(matcher, /*group_id=*/i);
2588*993b0882SAndroid Build Coastguard Worker       if (!group_match_text.has_value()) {
2589*993b0882SAndroid Build Coastguard Worker         continue;
2590*993b0882SAndroid Build Coastguard Worker       }
2591*993b0882SAndroid Build Coastguard Worker 
2592*993b0882SAndroid Build Coastguard Worker       // Set fixed entity data from capturing group match.
2593*993b0882SAndroid Build Coastguard Worker       if (group->serialized_entity_data() != nullptr) {
2594*993b0882SAndroid Build Coastguard Worker         entity_data->MergeFromSerializedFlatbuffer(
2595*993b0882SAndroid Build Coastguard Worker             StringPiece(group->serialized_entity_data()->c_str(),
2596*993b0882SAndroid Build Coastguard Worker                         group->serialized_entity_data()->size()));
2597*993b0882SAndroid Build Coastguard Worker       }
2598*993b0882SAndroid Build Coastguard Worker       if (group->entity_data() != nullptr) {
2599*993b0882SAndroid Build Coastguard Worker         entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2600*993b0882SAndroid Build Coastguard Worker             pattern->entity_data()));
2601*993b0882SAndroid Build Coastguard Worker       }
2602*993b0882SAndroid Build Coastguard Worker 
2603*993b0882SAndroid Build Coastguard Worker       // Set entity field from capturing group text.
2604*993b0882SAndroid Build Coastguard Worker       if (group->entity_field_path() != nullptr) {
2605*993b0882SAndroid Build Coastguard Worker         UnicodeText normalized_group_match_text =
2606*993b0882SAndroid Build Coastguard Worker             UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2607*993b0882SAndroid Build Coastguard Worker 
2608*993b0882SAndroid Build Coastguard Worker         // Apply normalization if specified.
2609*993b0882SAndroid Build Coastguard Worker         if (group->normalization_options() != nullptr) {
2610*993b0882SAndroid Build Coastguard Worker           normalized_group_match_text =
2611*993b0882SAndroid Build Coastguard Worker               NormalizeText(*unilib_, group->normalization_options(),
2612*993b0882SAndroid Build Coastguard Worker                             normalized_group_match_text);
2613*993b0882SAndroid Build Coastguard Worker         }
2614*993b0882SAndroid Build Coastguard Worker 
2615*993b0882SAndroid Build Coastguard Worker         if (!entity_data->ParseAndSet(
2616*993b0882SAndroid Build Coastguard Worker                 group->entity_field_path(),
2617*993b0882SAndroid Build Coastguard Worker                 normalized_group_match_text.ToUTF8String())) {
2618*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR)
2619*993b0882SAndroid Build Coastguard Worker               << "Could not set entity data from rule capturing group.";
2620*993b0882SAndroid Build Coastguard Worker           return false;
2621*993b0882SAndroid Build Coastguard Worker         }
2622*993b0882SAndroid Build Coastguard Worker       }
2623*993b0882SAndroid Build Coastguard Worker     }
2624*993b0882SAndroid Build Coastguard Worker   }
2625*993b0882SAndroid Build Coastguard Worker 
2626*993b0882SAndroid Build Coastguard Worker   *serialized_entity_data = entity_data->Serialize();
2627*993b0882SAndroid Build Coastguard Worker   return true;
2628*993b0882SAndroid Build Coastguard Worker }
2629*993b0882SAndroid Build Coastguard Worker 
// Returns the codepoints of `amount` that come before `it_decimal_separator`
// (all of `amount` when it equals `amount.end()`), with every codepoint found
// in `decimal_separators` removed.
//
// Despite its name, `decimal_separators` is the full set of money separators
// to strip from the whole part of the amount.
UnicodeText RemoveMoneySeparators(
    const std::unordered_set<char32>& decimal_separators,
    const UnicodeText& amount,
    UnicodeText::const_iterator it_decimal_separator) {
  UnicodeText whole_amount;
  for (auto it = amount.begin();
       it != amount.end() && it != it_decimal_separator; ++it) {
    const char32 codepoint = static_cast<char32>(*it);
    // Use the hash set's own lookup (O(1)) rather than a linear std::find.
    if (decimal_separators.find(codepoint) == decimal_separators.end()) {
      whole_amount.push_back(*it);
    }
  }
  return whole_amount;
}
2644*993b0882SAndroid Build Coastguard Worker 
GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode,std::string * quantity,int * exponent) const2645*993b0882SAndroid Build Coastguard Worker void Annotator::GetMoneyQuantityFromCapturingGroup(
2646*993b0882SAndroid Build Coastguard Worker     const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2647*993b0882SAndroid Build Coastguard Worker     const UnicodeText& context_unicode, std::string* quantity,
2648*993b0882SAndroid Build Coastguard Worker     int* exponent) const {
2649*993b0882SAndroid Build Coastguard Worker   if (config->capturing_group() == nullptr) {
2650*993b0882SAndroid Build Coastguard Worker     *exponent = 0;
2651*993b0882SAndroid Build Coastguard Worker     return;
2652*993b0882SAndroid Build Coastguard Worker   }
2653*993b0882SAndroid Build Coastguard Worker 
2654*993b0882SAndroid Build Coastguard Worker   const int num_groups = config->capturing_group()->size();
2655*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_groups; i++) {
2656*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
2657*993b0882SAndroid Build Coastguard Worker     const int group_start = match->Start(i, &status);
2658*993b0882SAndroid Build Coastguard Worker     const int group_end = match->End(i, &status);
2659*993b0882SAndroid Build Coastguard Worker     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2660*993b0882SAndroid Build Coastguard Worker       continue;
2661*993b0882SAndroid Build Coastguard Worker     }
2662*993b0882SAndroid Build Coastguard Worker 
2663*993b0882SAndroid Build Coastguard Worker     *quantity =
2664*993b0882SAndroid Build Coastguard Worker         unilib_
2665*993b0882SAndroid Build Coastguard Worker             ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2666*993b0882SAndroid Build Coastguard Worker                                                  group_end, /*do_copy=*/false))
2667*993b0882SAndroid Build Coastguard Worker             .ToUTF8String();
2668*993b0882SAndroid Build Coastguard Worker 
2669*993b0882SAndroid Build Coastguard Worker     if (auto entry = model_->money_parsing_options()
2670*993b0882SAndroid Build Coastguard Worker                          ->quantities_name_to_exponent()
2671*993b0882SAndroid Build Coastguard Worker                          ->LookupByKey((*quantity).c_str())) {
2672*993b0882SAndroid Build Coastguard Worker       *exponent = entry->value();
2673*993b0882SAndroid Build Coastguard Worker       return;
2674*993b0882SAndroid Build Coastguard Worker     }
2675*993b0882SAndroid Build Coastguard Worker   }
2676*993b0882SAndroid Build Coastguard Worker   *exponent = 0;
2677*993b0882SAndroid Build Coastguard Worker }
2678*993b0882SAndroid Build Coastguard Worker 
ParseAndFillInMoneyAmount(std::string * serialized_entity_data,const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode) const2679*993b0882SAndroid Build Coastguard Worker bool Annotator::ParseAndFillInMoneyAmount(
2680*993b0882SAndroid Build Coastguard Worker     std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2681*993b0882SAndroid Build Coastguard Worker     const RegexModel_::Pattern* config,
2682*993b0882SAndroid Build Coastguard Worker     const UnicodeText& context_unicode) const {
2683*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<EntityDataT> data =
2684*993b0882SAndroid Build Coastguard Worker       LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2685*993b0882SAndroid Build Coastguard Worker           *serialized_entity_data);
2686*993b0882SAndroid Build Coastguard Worker   if (data == nullptr) {
2687*993b0882SAndroid Build Coastguard Worker     if (model_->version() >= 706) {
2688*993b0882SAndroid Build Coastguard Worker       // This way of parsing money entity data is enabled for models newer than
2689*993b0882SAndroid Build Coastguard Worker       // v706, consequently logging errors only for them (b/156634162).
2690*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR)
2691*993b0882SAndroid Build Coastguard Worker           << "Data field is null when trying to parse Money Entity Data";
2692*993b0882SAndroid Build Coastguard Worker     }
2693*993b0882SAndroid Build Coastguard Worker     return false;
2694*993b0882SAndroid Build Coastguard Worker   }
2695*993b0882SAndroid Build Coastguard Worker   if (data->money->unnormalized_amount.empty()) {
2696*993b0882SAndroid Build Coastguard Worker     if (model_->version() >= 706) {
2697*993b0882SAndroid Build Coastguard Worker       // This way of parsing money entity data is enabled for models newer than
2698*993b0882SAndroid Build Coastguard Worker       // v706, consequently logging errors only for them (b/156634162).
2699*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR)
2700*993b0882SAndroid Build Coastguard Worker           << "Data unnormalized_amount is empty when trying to parse "
2701*993b0882SAndroid Build Coastguard Worker              "Money Entity Data";
2702*993b0882SAndroid Build Coastguard Worker     }
2703*993b0882SAndroid Build Coastguard Worker     return false;
2704*993b0882SAndroid Build Coastguard Worker   }
2705*993b0882SAndroid Build Coastguard Worker 
2706*993b0882SAndroid Build Coastguard Worker   UnicodeText amount =
2707*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2708*993b0882SAndroid Build Coastguard Worker   int separator_back_index = 0;
2709*993b0882SAndroid Build Coastguard Worker   auto it_decimal_separator = --amount.end();
2710*993b0882SAndroid Build Coastguard Worker   for (; it_decimal_separator != amount.begin();
2711*993b0882SAndroid Build Coastguard Worker        --it_decimal_separator, ++separator_back_index) {
2712*993b0882SAndroid Build Coastguard Worker     if (std::find(money_separators_.begin(), money_separators_.end(),
2713*993b0882SAndroid Build Coastguard Worker                   static_cast<char32>(*it_decimal_separator)) !=
2714*993b0882SAndroid Build Coastguard Worker         money_separators_.end()) {
2715*993b0882SAndroid Build Coastguard Worker       break;
2716*993b0882SAndroid Build Coastguard Worker     }
2717*993b0882SAndroid Build Coastguard Worker   }
2718*993b0882SAndroid Build Coastguard Worker 
2719*993b0882SAndroid Build Coastguard Worker   // If there are 3 digits after the last separator, we consider that a
2720*993b0882SAndroid Build Coastguard Worker   // thousands separator => the number is an int (e.g. 1.234 is considered int).
2721*993b0882SAndroid Build Coastguard Worker   // If there is no separator in number, also that number is an int.
2722*993b0882SAndroid Build Coastguard Worker   if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2723*993b0882SAndroid Build Coastguard Worker     it_decimal_separator = amount.end();
2724*993b0882SAndroid Build Coastguard Worker   }
2725*993b0882SAndroid Build Coastguard Worker 
2726*993b0882SAndroid Build Coastguard Worker   if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2727*993b0882SAndroid Build Coastguard Worker                                                  it_decimal_separator),
2728*993b0882SAndroid Build Coastguard Worker                            &data->money->amount_whole_part)) {
2729*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2730*993b0882SAndroid Build Coastguard Worker                       "amount: "
2731*993b0882SAndroid Build Coastguard Worker                    << data->money->unnormalized_amount;
2732*993b0882SAndroid Build Coastguard Worker     return false;
2733*993b0882SAndroid Build Coastguard Worker   }
2734*993b0882SAndroid Build Coastguard Worker 
2735*993b0882SAndroid Build Coastguard Worker   if (it_decimal_separator == amount.end()) {
2736*993b0882SAndroid Build Coastguard Worker     data->money->amount_decimal_part = 0;
2737*993b0882SAndroid Build Coastguard Worker     data->money->nanos = 0;
2738*993b0882SAndroid Build Coastguard Worker   } else {
2739*993b0882SAndroid Build Coastguard Worker     const int amount_codepoints_size = amount.size_codepoints();
2740*993b0882SAndroid Build Coastguard Worker     const UnicodeText decimal_part = UnicodeText::Substring(
2741*993b0882SAndroid Build Coastguard Worker         amount, amount_codepoints_size - separator_back_index,
2742*993b0882SAndroid Build Coastguard Worker         amount_codepoints_size, /*do_copy=*/false);
2743*993b0882SAndroid Build Coastguard Worker     if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2744*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2745*993b0882SAndroid Build Coastguard Worker                         "the amount: "
2746*993b0882SAndroid Build Coastguard Worker                      << data->money->unnormalized_amount;
2747*993b0882SAndroid Build Coastguard Worker       return false;
2748*993b0882SAndroid Build Coastguard Worker     }
2749*993b0882SAndroid Build Coastguard Worker     data->money->nanos = data->money->amount_decimal_part *
2750*993b0882SAndroid Build Coastguard Worker                          pow(10, 9 - decimal_part.size_codepoints());
2751*993b0882SAndroid Build Coastguard Worker   }
2752*993b0882SAndroid Build Coastguard Worker 
2753*993b0882SAndroid Build Coastguard Worker   if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2754*993b0882SAndroid Build Coastguard Worker       nullptr) {
2755*993b0882SAndroid Build Coastguard Worker     int quantity_exponent;
2756*993b0882SAndroid Build Coastguard Worker     std::string quantity;
2757*993b0882SAndroid Build Coastguard Worker     GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2758*993b0882SAndroid Build Coastguard Worker                                        &quantity, &quantity_exponent);
2759*993b0882SAndroid Build Coastguard Worker     if (quantity_exponent > 0 && quantity_exponent <= 9) {
2760*993b0882SAndroid Build Coastguard Worker       const double amount_whole_part =
2761*993b0882SAndroid Build Coastguard Worker           data->money->amount_whole_part * pow(10, quantity_exponent) +
2762*993b0882SAndroid Build Coastguard Worker           data->money->nanos / pow(10, 9 - quantity_exponent);
2763*993b0882SAndroid Build Coastguard Worker       // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2764*993b0882SAndroid Build Coastguard Worker       // (and `std::numeric_limits<int>::max()` to
2765*993b0882SAndroid Build Coastguard Worker       // `std::numeric_limits<int64>::max()`).
2766*993b0882SAndroid Build Coastguard Worker       if (amount_whole_part < std::numeric_limits<int>::max()) {
2767*993b0882SAndroid Build Coastguard Worker         data->money->amount_whole_part = amount_whole_part;
2768*993b0882SAndroid Build Coastguard Worker         data->money->nanos = data->money->nanos %
2769*993b0882SAndroid Build Coastguard Worker                              static_cast<int>(pow(10, 9 - quantity_exponent)) *
2770*993b0882SAndroid Build Coastguard Worker                              pow(10, quantity_exponent);
2771*993b0882SAndroid Build Coastguard Worker       }
2772*993b0882SAndroid Build Coastguard Worker     }
2773*993b0882SAndroid Build Coastguard Worker     if (quantity_exponent > 0) {
2774*993b0882SAndroid Build Coastguard Worker       data->money->unnormalized_amount = strings::JoinStrings(
2775*993b0882SAndroid Build Coastguard Worker           " ", {data->money->unnormalized_amount, quantity});
2776*993b0882SAndroid Build Coastguard Worker     }
2777*993b0882SAndroid Build Coastguard Worker   }
2778*993b0882SAndroid Build Coastguard Worker 
2779*993b0882SAndroid Build Coastguard Worker   *serialized_entity_data =
2780*993b0882SAndroid Build Coastguard Worker       PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2781*993b0882SAndroid Build Coastguard Worker   return true;
2782*993b0882SAndroid Build Coastguard Worker }
2783*993b0882SAndroid Build Coastguard Worker 
IsAnyModelEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2784*993b0882SAndroid Build Coastguard Worker bool Annotator::IsAnyModelEntityTypeEnabled(
2785*993b0882SAndroid Build Coastguard Worker     const EnabledEntityTypes& is_entity_type_enabled) const {
2786*993b0882SAndroid Build Coastguard Worker   if (model_->classification_feature_options() == nullptr ||
2787*993b0882SAndroid Build Coastguard Worker       model_->classification_feature_options()->collections() == nullptr) {
2788*993b0882SAndroid Build Coastguard Worker     return false;
2789*993b0882SAndroid Build Coastguard Worker   }
2790*993b0882SAndroid Build Coastguard Worker   for (int i = 0;
2791*993b0882SAndroid Build Coastguard Worker        i < model_->classification_feature_options()->collections()->size();
2792*993b0882SAndroid Build Coastguard Worker        i++) {
2793*993b0882SAndroid Build Coastguard Worker     if (is_entity_type_enabled(model_->classification_feature_options()
2794*993b0882SAndroid Build Coastguard Worker                                    ->collections()
2795*993b0882SAndroid Build Coastguard Worker                                    ->Get(i)
2796*993b0882SAndroid Build Coastguard Worker                                    ->str())) {
2797*993b0882SAndroid Build Coastguard Worker       return true;
2798*993b0882SAndroid Build Coastguard Worker     }
2799*993b0882SAndroid Build Coastguard Worker   }
2800*993b0882SAndroid Build Coastguard Worker   return false;
2801*993b0882SAndroid Build Coastguard Worker }
2802*993b0882SAndroid Build Coastguard Worker 
IsAnyRegexEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2803*993b0882SAndroid Build Coastguard Worker bool Annotator::IsAnyRegexEntityTypeEnabled(
2804*993b0882SAndroid Build Coastguard Worker     const EnabledEntityTypes& is_entity_type_enabled) const {
2805*993b0882SAndroid Build Coastguard Worker   if (model_->regex_model() == nullptr ||
2806*993b0882SAndroid Build Coastguard Worker       model_->regex_model()->patterns() == nullptr) {
2807*993b0882SAndroid Build Coastguard Worker     return false;
2808*993b0882SAndroid Build Coastguard Worker   }
2809*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2810*993b0882SAndroid Build Coastguard Worker     if (is_entity_type_enabled(model_->regex_model()
2811*993b0882SAndroid Build Coastguard Worker                                    ->patterns()
2812*993b0882SAndroid Build Coastguard Worker                                    ->Get(i)
2813*993b0882SAndroid Build Coastguard Worker                                    ->collection_name()
2814*993b0882SAndroid Build Coastguard Worker                                    ->str())) {
2815*993b0882SAndroid Build Coastguard Worker       return true;
2816*993b0882SAndroid Build Coastguard Worker     }
2817*993b0882SAndroid Build Coastguard Worker   }
2818*993b0882SAndroid Build Coastguard Worker   return false;
2819*993b0882SAndroid Build Coastguard Worker }
2820*993b0882SAndroid Build Coastguard Worker 
IsAnyPodNerEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2821*993b0882SAndroid Build Coastguard Worker bool Annotator::IsAnyPodNerEntityTypeEnabled(
2822*993b0882SAndroid Build Coastguard Worker     const EnabledEntityTypes& is_entity_type_enabled) const {
2823*993b0882SAndroid Build Coastguard Worker   if (pod_ner_annotator_ == nullptr) {
2824*993b0882SAndroid Build Coastguard Worker     return false;
2825*993b0882SAndroid Build Coastguard Worker   }
2826*993b0882SAndroid Build Coastguard Worker 
2827*993b0882SAndroid Build Coastguard Worker   for (const std::string& collection :
2828*993b0882SAndroid Build Coastguard Worker        pod_ner_annotator_->GetSupportedCollections()) {
2829*993b0882SAndroid Build Coastguard Worker     if (is_entity_type_enabled(collection)) {
2830*993b0882SAndroid Build Coastguard Worker       return true;
2831*993b0882SAndroid Build Coastguard Worker     }
2832*993b0882SAndroid Build Coastguard Worker   }
2833*993b0882SAndroid Build Coastguard Worker   return false;
2834*993b0882SAndroid Build Coastguard Worker }
2835*993b0882SAndroid Build Coastguard Worker 
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,bool is_serialized_entity_data_enabled,const EnabledEntityTypes & enabled_entity_types,const AnnotationUsecase & annotation_usecase,std::vector<AnnotatedSpan> * result) const2836*993b0882SAndroid Build Coastguard Worker bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2837*993b0882SAndroid Build Coastguard Worker                            const std::vector<int>& rules,
2838*993b0882SAndroid Build Coastguard Worker                            bool is_serialized_entity_data_enabled,
2839*993b0882SAndroid Build Coastguard Worker                            const EnabledEntityTypes& enabled_entity_types,
2840*993b0882SAndroid Build Coastguard Worker                            const AnnotationUsecase& annotation_usecase,
2841*993b0882SAndroid Build Coastguard Worker                            std::vector<AnnotatedSpan>* result) const {
2842*993b0882SAndroid Build Coastguard Worker   for (int pattern_id : rules) {
2843*993b0882SAndroid Build Coastguard Worker     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2844*993b0882SAndroid Build Coastguard Worker     if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2845*993b0882SAndroid Build Coastguard Worker         annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2846*993b0882SAndroid Build Coastguard Worker       // No regex annotation type has been requested, skip regex annotation.
2847*993b0882SAndroid Build Coastguard Worker       continue;
2848*993b0882SAndroid Build Coastguard Worker     }
2849*993b0882SAndroid Build Coastguard Worker     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2850*993b0882SAndroid Build Coastguard Worker     if (!matcher) {
2851*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2852*993b0882SAndroid Build Coastguard Worker                      << pattern_id;
2853*993b0882SAndroid Build Coastguard Worker       return false;
2854*993b0882SAndroid Build Coastguard Worker     }
2855*993b0882SAndroid Build Coastguard Worker 
2856*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
2857*993b0882SAndroid Build Coastguard Worker     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2858*993b0882SAndroid Build Coastguard Worker       if (regex_pattern.config->verification_options()) {
2859*993b0882SAndroid Build Coastguard Worker         if (!VerifyRegexMatchCandidate(
2860*993b0882SAndroid Build Coastguard Worker                 context_unicode.ToUTF8String(),
2861*993b0882SAndroid Build Coastguard Worker                 regex_pattern.config->verification_options(),
2862*993b0882SAndroid Build Coastguard Worker                 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2863*993b0882SAndroid Build Coastguard Worker           continue;
2864*993b0882SAndroid Build Coastguard Worker         }
2865*993b0882SAndroid Build Coastguard Worker       }
2866*993b0882SAndroid Build Coastguard Worker 
2867*993b0882SAndroid Build Coastguard Worker       std::string serialized_entity_data;
2868*993b0882SAndroid Build Coastguard Worker       if (is_serialized_entity_data_enabled) {
2869*993b0882SAndroid Build Coastguard Worker         if (!SerializedEntityDataFromRegexMatch(
2870*993b0882SAndroid Build Coastguard Worker                 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2871*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR) << "Could not get entity data.";
2872*993b0882SAndroid Build Coastguard Worker           return false;
2873*993b0882SAndroid Build Coastguard Worker         }
2874*993b0882SAndroid Build Coastguard Worker 
2875*993b0882SAndroid Build Coastguard Worker         // Further parsing of money amount. Need this since regexes cannot have
2876*993b0882SAndroid Build Coastguard Worker         // empty groups that fill in entity data (amount_decimal_part and
2877*993b0882SAndroid Build Coastguard Worker         // quantity might be empty groups).
2878*993b0882SAndroid Build Coastguard Worker         if (regex_pattern.config->collection_name()->str() ==
2879*993b0882SAndroid Build Coastguard Worker             Collections::Money()) {
2880*993b0882SAndroid Build Coastguard Worker           if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2881*993b0882SAndroid Build Coastguard Worker                                          regex_pattern.config,
2882*993b0882SAndroid Build Coastguard Worker                                          context_unicode)) {
2883*993b0882SAndroid Build Coastguard Worker             if (model_->version() >= 706) {
2884*993b0882SAndroid Build Coastguard Worker               // This way of parsing money entity data is enabled for models
2885*993b0882SAndroid Build Coastguard Worker               // newer than v706 => logging errors only for them (b/156634162).
2886*993b0882SAndroid Build Coastguard Worker               TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2887*993b0882SAndroid Build Coastguard Worker             }
2888*993b0882SAndroid Build Coastguard Worker           }
2889*993b0882SAndroid Build Coastguard Worker         }
2890*993b0882SAndroid Build Coastguard Worker       }
2891*993b0882SAndroid Build Coastguard Worker 
2892*993b0882SAndroid Build Coastguard Worker       result->emplace_back();
2893*993b0882SAndroid Build Coastguard Worker 
2894*993b0882SAndroid Build Coastguard Worker       // Selection/annotation regular expressions need to specify a capturing
2895*993b0882SAndroid Build Coastguard Worker       // group specifying the selection.
2896*993b0882SAndroid Build Coastguard Worker       result->back().span =
2897*993b0882SAndroid Build Coastguard Worker           ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2898*993b0882SAndroid Build Coastguard Worker 
2899*993b0882SAndroid Build Coastguard Worker       result->back().classification = {
2900*993b0882SAndroid Build Coastguard Worker           {regex_pattern.config->collection_name()->str(),
2901*993b0882SAndroid Build Coastguard Worker            regex_pattern.config->target_classification_score(),
2902*993b0882SAndroid Build Coastguard Worker            regex_pattern.config->priority_score()}};
2903*993b0882SAndroid Build Coastguard Worker 
2904*993b0882SAndroid Build Coastguard Worker       result->back().classification[0].serialized_entity_data =
2905*993b0882SAndroid Build Coastguard Worker           serialized_entity_data;
2906*993b0882SAndroid Build Coastguard Worker     }
2907*993b0882SAndroid Build Coastguard Worker   }
2908*993b0882SAndroid Build Coastguard Worker   return true;
2909*993b0882SAndroid Build Coastguard Worker }
2910*993b0882SAndroid Build Coastguard Worker 
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2911*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2912*993b0882SAndroid Build Coastguard Worker                            tflite::Interpreter* selection_interpreter,
2913*993b0882SAndroid Build Coastguard Worker                            const CachedFeatures& cached_features,
2914*993b0882SAndroid Build Coastguard Worker                            std::vector<TokenSpan>* chunks) const {
2915*993b0882SAndroid Build Coastguard Worker   const int max_selection_span =
2916*993b0882SAndroid Build Coastguard Worker       selection_feature_processor_->GetOptions()->max_selection_span();
2917*993b0882SAndroid Build Coastguard Worker   // The inference span is the span of interest expanded to include
2918*993b0882SAndroid Build Coastguard Worker   // max_selection_span tokens on either side, which is how far a selection can
2919*993b0882SAndroid Build Coastguard Worker   // stretch from the click.
2920*993b0882SAndroid Build Coastguard Worker   const TokenSpan inference_span =
2921*993b0882SAndroid Build Coastguard Worker       IntersectTokenSpans(span_of_interest.Expand(
2922*993b0882SAndroid Build Coastguard Worker                               /*num_tokens_left=*/max_selection_span,
2923*993b0882SAndroid Build Coastguard Worker                               /*num_tokens_right=*/max_selection_span),
2924*993b0882SAndroid Build Coastguard Worker                           {0, num_tokens});
2925*993b0882SAndroid Build Coastguard Worker 
2926*993b0882SAndroid Build Coastguard Worker   std::vector<ScoredChunk> scored_chunks;
2927*993b0882SAndroid Build Coastguard Worker   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2928*993b0882SAndroid Build Coastguard Worker       selection_feature_processor_->GetOptions()
2929*993b0882SAndroid Build Coastguard Worker           ->bounds_sensitive_features()
2930*993b0882SAndroid Build Coastguard Worker           ->enabled()) {
2931*993b0882SAndroid Build Coastguard Worker     if (!ModelBoundsSensitiveScoreChunks(
2932*993b0882SAndroid Build Coastguard Worker             num_tokens, span_of_interest, inference_span, cached_features,
2933*993b0882SAndroid Build Coastguard Worker             selection_interpreter, &scored_chunks)) {
2934*993b0882SAndroid Build Coastguard Worker       return false;
2935*993b0882SAndroid Build Coastguard Worker     }
2936*993b0882SAndroid Build Coastguard Worker   } else {
2937*993b0882SAndroid Build Coastguard Worker     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2938*993b0882SAndroid Build Coastguard Worker                                       cached_features, selection_interpreter,
2939*993b0882SAndroid Build Coastguard Worker                                       &scored_chunks)) {
2940*993b0882SAndroid Build Coastguard Worker       return false;
2941*993b0882SAndroid Build Coastguard Worker     }
2942*993b0882SAndroid Build Coastguard Worker   }
2943*993b0882SAndroid Build Coastguard Worker   std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
2944*993b0882SAndroid Build Coastguard Worker                    [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2945*993b0882SAndroid Build Coastguard Worker                      return lhs.score < rhs.score;
2946*993b0882SAndroid Build Coastguard Worker                    });
2947*993b0882SAndroid Build Coastguard Worker 
2948*993b0882SAndroid Build Coastguard Worker   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2949*993b0882SAndroid Build Coastguard Worker   // them greedily as long as they do not overlap with any previously picked
2950*993b0882SAndroid Build Coastguard Worker   // chunks.
2951*993b0882SAndroid Build Coastguard Worker   std::vector<bool> token_used(inference_span.Size());
2952*993b0882SAndroid Build Coastguard Worker   chunks->clear();
2953*993b0882SAndroid Build Coastguard Worker   for (const ScoredChunk& scored_chunk : scored_chunks) {
2954*993b0882SAndroid Build Coastguard Worker     bool feasible = true;
2955*993b0882SAndroid Build Coastguard Worker     for (int i = scored_chunk.token_span.first;
2956*993b0882SAndroid Build Coastguard Worker          i < scored_chunk.token_span.second; ++i) {
2957*993b0882SAndroid Build Coastguard Worker       if (token_used[i - inference_span.first]) {
2958*993b0882SAndroid Build Coastguard Worker         feasible = false;
2959*993b0882SAndroid Build Coastguard Worker         break;
2960*993b0882SAndroid Build Coastguard Worker       }
2961*993b0882SAndroid Build Coastguard Worker     }
2962*993b0882SAndroid Build Coastguard Worker 
2963*993b0882SAndroid Build Coastguard Worker     if (!feasible) {
2964*993b0882SAndroid Build Coastguard Worker       continue;
2965*993b0882SAndroid Build Coastguard Worker     }
2966*993b0882SAndroid Build Coastguard Worker 
2967*993b0882SAndroid Build Coastguard Worker     for (int i = scored_chunk.token_span.first;
2968*993b0882SAndroid Build Coastguard Worker          i < scored_chunk.token_span.second; ++i) {
2969*993b0882SAndroid Build Coastguard Worker       token_used[i - inference_span.first] = true;
2970*993b0882SAndroid Build Coastguard Worker     }
2971*993b0882SAndroid Build Coastguard Worker 
2972*993b0882SAndroid Build Coastguard Worker     chunks->push_back(scored_chunk.token_span);
2973*993b0882SAndroid Build Coastguard Worker   }
2974*993b0882SAndroid Build Coastguard Worker 
2975*993b0882SAndroid Build Coastguard Worker   std::stable_sort(chunks->begin(), chunks->end());
2976*993b0882SAndroid Build Coastguard Worker 
2977*993b0882SAndroid Build Coastguard Worker   return true;
2978*993b0882SAndroid Build Coastguard Worker }
2979*993b0882SAndroid Build Coastguard Worker 
namespace {
// Records `value` under `key` in `map`. If `key` is already present, the
// stored value becomes the larger of the existing value and `value`;
// otherwise `value` is inserted as-is.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  auto emplaced = map->emplace(key, value);
  const bool was_inserted = emplaced.second;
  if (!was_inserted) {
    auto& stored = emplaced.first->second;
    stored = std::max(stored, value);
  }
}
}  // namespace
2994*993b0882SAndroid Build Coastguard Worker 
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2995*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelClickContextScoreChunks(
2996*993b0882SAndroid Build Coastguard Worker     int num_tokens, const TokenSpan& span_of_interest,
2997*993b0882SAndroid Build Coastguard Worker     const CachedFeatures& cached_features,
2998*993b0882SAndroid Build Coastguard Worker     tflite::Interpreter* selection_interpreter,
2999*993b0882SAndroid Build Coastguard Worker     std::vector<ScoredChunk>* scored_chunks) const {
3000*993b0882SAndroid Build Coastguard Worker   const int max_batch_size = model_->selection_options()->batch_size();
3001*993b0882SAndroid Build Coastguard Worker 
3002*993b0882SAndroid Build Coastguard Worker   std::vector<float> all_features;
3003*993b0882SAndroid Build Coastguard Worker   std::map<TokenSpan, float> chunk_scores;
3004*993b0882SAndroid Build Coastguard Worker   for (int batch_start = span_of_interest.first;
3005*993b0882SAndroid Build Coastguard Worker        batch_start < span_of_interest.second; batch_start += max_batch_size) {
3006*993b0882SAndroid Build Coastguard Worker     const int batch_end =
3007*993b0882SAndroid Build Coastguard Worker         std::min(batch_start + max_batch_size, span_of_interest.second);
3008*993b0882SAndroid Build Coastguard Worker 
3009*993b0882SAndroid Build Coastguard Worker     // Prepare features for the whole batch.
3010*993b0882SAndroid Build Coastguard Worker     all_features.clear();
3011*993b0882SAndroid Build Coastguard Worker     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3012*993b0882SAndroid Build Coastguard Worker     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3013*993b0882SAndroid Build Coastguard Worker       cached_features.AppendClickContextFeaturesForClick(click_pos,
3014*993b0882SAndroid Build Coastguard Worker                                                          &all_features);
3015*993b0882SAndroid Build Coastguard Worker     }
3016*993b0882SAndroid Build Coastguard Worker 
3017*993b0882SAndroid Build Coastguard Worker     // Run batched inference.
3018*993b0882SAndroid Build Coastguard Worker     const int batch_size = batch_end - batch_start;
3019*993b0882SAndroid Build Coastguard Worker     const int features_size = cached_features.OutputFeaturesSize();
3020*993b0882SAndroid Build Coastguard Worker     TensorView<float> logits = selection_executor_->ComputeLogits(
3021*993b0882SAndroid Build Coastguard Worker         TensorView<float>(all_features.data(), {batch_size, features_size}),
3022*993b0882SAndroid Build Coastguard Worker         selection_interpreter);
3023*993b0882SAndroid Build Coastguard Worker     if (!logits.is_valid()) {
3024*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Couldn't compute logits.";
3025*993b0882SAndroid Build Coastguard Worker       return false;
3026*993b0882SAndroid Build Coastguard Worker     }
3027*993b0882SAndroid Build Coastguard Worker     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3028*993b0882SAndroid Build Coastguard Worker         logits.dim(1) !=
3029*993b0882SAndroid Build Coastguard Worker             selection_feature_processor_->GetSelectionLabelCount()) {
3030*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Mismatching output.";
3031*993b0882SAndroid Build Coastguard Worker       return false;
3032*993b0882SAndroid Build Coastguard Worker     }
3033*993b0882SAndroid Build Coastguard Worker 
3034*993b0882SAndroid Build Coastguard Worker     // Save results.
3035*993b0882SAndroid Build Coastguard Worker     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3036*993b0882SAndroid Build Coastguard Worker       const std::vector<float> scores = ComputeSoftmax(
3037*993b0882SAndroid Build Coastguard Worker           logits.data() + logits.dim(1) * (click_pos - batch_start),
3038*993b0882SAndroid Build Coastguard Worker           logits.dim(1));
3039*993b0882SAndroid Build Coastguard Worker       for (int j = 0;
3040*993b0882SAndroid Build Coastguard Worker            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3041*993b0882SAndroid Build Coastguard Worker         TokenSpan relative_token_span;
3042*993b0882SAndroid Build Coastguard Worker         if (!selection_feature_processor_->LabelToTokenSpan(
3043*993b0882SAndroid Build Coastguard Worker                 j, &relative_token_span)) {
3044*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
3045*993b0882SAndroid Build Coastguard Worker           return false;
3046*993b0882SAndroid Build Coastguard Worker         }
3047*993b0882SAndroid Build Coastguard Worker         const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3048*993b0882SAndroid Build Coastguard Worker             relative_token_span.first, relative_token_span.second);
3049*993b0882SAndroid Build Coastguard Worker         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3050*993b0882SAndroid Build Coastguard Worker           UpdateMax(&chunk_scores, candidate_span, scores[j]);
3051*993b0882SAndroid Build Coastguard Worker         }
3052*993b0882SAndroid Build Coastguard Worker       }
3053*993b0882SAndroid Build Coastguard Worker     }
3054*993b0882SAndroid Build Coastguard Worker   }
3055*993b0882SAndroid Build Coastguard Worker 
3056*993b0882SAndroid Build Coastguard Worker   scored_chunks->clear();
3057*993b0882SAndroid Build Coastguard Worker   scored_chunks->reserve(chunk_scores.size());
3058*993b0882SAndroid Build Coastguard Worker   for (const auto& entry : chunk_scores) {
3059*993b0882SAndroid Build Coastguard Worker     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3060*993b0882SAndroid Build Coastguard Worker   }
3061*993b0882SAndroid Build Coastguard Worker 
3062*993b0882SAndroid Build Coastguard Worker   return true;
3063*993b0882SAndroid Build Coastguard Worker }
3064*993b0882SAndroid Build Coastguard Worker 
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const3065*993b0882SAndroid Build Coastguard Worker bool Annotator::ModelBoundsSensitiveScoreChunks(
3066*993b0882SAndroid Build Coastguard Worker     int num_tokens, const TokenSpan& span_of_interest,
3067*993b0882SAndroid Build Coastguard Worker     const TokenSpan& inference_span, const CachedFeatures& cached_features,
3068*993b0882SAndroid Build Coastguard Worker     tflite::Interpreter* selection_interpreter,
3069*993b0882SAndroid Build Coastguard Worker     std::vector<ScoredChunk>* scored_chunks) const {
3070*993b0882SAndroid Build Coastguard Worker   const int max_selection_span =
3071*993b0882SAndroid Build Coastguard Worker       selection_feature_processor_->GetOptions()->max_selection_span();
3072*993b0882SAndroid Build Coastguard Worker   const int max_chunk_length = selection_feature_processor_->GetOptions()
3073*993b0882SAndroid Build Coastguard Worker                                        ->selection_reduced_output_space()
3074*993b0882SAndroid Build Coastguard Worker                                    ? max_selection_span + 1
3075*993b0882SAndroid Build Coastguard Worker                                    : 2 * max_selection_span + 1;
3076*993b0882SAndroid Build Coastguard Worker   const bool score_single_token_spans_as_zero =
3077*993b0882SAndroid Build Coastguard Worker       selection_feature_processor_->GetOptions()
3078*993b0882SAndroid Build Coastguard Worker           ->bounds_sensitive_features()
3079*993b0882SAndroid Build Coastguard Worker           ->score_single_token_spans_as_zero();
3080*993b0882SAndroid Build Coastguard Worker 
3081*993b0882SAndroid Build Coastguard Worker   scored_chunks->clear();
3082*993b0882SAndroid Build Coastguard Worker   if (score_single_token_spans_as_zero) {
3083*993b0882SAndroid Build Coastguard Worker     scored_chunks->reserve(span_of_interest.Size());
3084*993b0882SAndroid Build Coastguard Worker   }
3085*993b0882SAndroid Build Coastguard Worker 
3086*993b0882SAndroid Build Coastguard Worker   // Prepare all chunk candidates into one batch:
3087*993b0882SAndroid Build Coastguard Worker   //   - Are contained in the inference span
3088*993b0882SAndroid Build Coastguard Worker   //   - Have a non-empty intersection with the span of interest
3089*993b0882SAndroid Build Coastguard Worker   //   - Are at least one token long
3090*993b0882SAndroid Build Coastguard Worker   //   - Are not longer than the maximum chunk length
3091*993b0882SAndroid Build Coastguard Worker   std::vector<TokenSpan> candidate_spans;
3092*993b0882SAndroid Build Coastguard Worker   for (int start = inference_span.first; start < span_of_interest.second;
3093*993b0882SAndroid Build Coastguard Worker        ++start) {
3094*993b0882SAndroid Build Coastguard Worker     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3095*993b0882SAndroid Build Coastguard Worker     for (int end = leftmost_end_index;
3096*993b0882SAndroid Build Coastguard Worker          end <= inference_span.second && end - start <= max_chunk_length;
3097*993b0882SAndroid Build Coastguard Worker          ++end) {
3098*993b0882SAndroid Build Coastguard Worker       const TokenSpan candidate_span = {start, end};
3099*993b0882SAndroid Build Coastguard Worker       if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
3100*993b0882SAndroid Build Coastguard Worker         // Do not include the single token span in the batch, add a zero score
3101*993b0882SAndroid Build Coastguard Worker         // for it directly to the output.
3102*993b0882SAndroid Build Coastguard Worker         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3103*993b0882SAndroid Build Coastguard Worker       } else {
3104*993b0882SAndroid Build Coastguard Worker         candidate_spans.push_back(candidate_span);
3105*993b0882SAndroid Build Coastguard Worker       }
3106*993b0882SAndroid Build Coastguard Worker     }
3107*993b0882SAndroid Build Coastguard Worker   }
3108*993b0882SAndroid Build Coastguard Worker 
3109*993b0882SAndroid Build Coastguard Worker   const int max_batch_size = model_->selection_options()->batch_size();
3110*993b0882SAndroid Build Coastguard Worker 
3111*993b0882SAndroid Build Coastguard Worker   std::vector<float> all_features;
3112*993b0882SAndroid Build Coastguard Worker   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
3113*993b0882SAndroid Build Coastguard Worker   for (int batch_start = 0; batch_start < candidate_spans.size();
3114*993b0882SAndroid Build Coastguard Worker        batch_start += max_batch_size) {
3115*993b0882SAndroid Build Coastguard Worker     const int batch_end = std::min(batch_start + max_batch_size,
3116*993b0882SAndroid Build Coastguard Worker                                    static_cast<int>(candidate_spans.size()));
3117*993b0882SAndroid Build Coastguard Worker 
3118*993b0882SAndroid Build Coastguard Worker     // Prepare features for the whole batch.
3119*993b0882SAndroid Build Coastguard Worker     all_features.clear();
3120*993b0882SAndroid Build Coastguard Worker     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3121*993b0882SAndroid Build Coastguard Worker     for (int i = batch_start; i < batch_end; ++i) {
3122*993b0882SAndroid Build Coastguard Worker       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3123*993b0882SAndroid Build Coastguard Worker                                                            &all_features);
3124*993b0882SAndroid Build Coastguard Worker     }
3125*993b0882SAndroid Build Coastguard Worker 
3126*993b0882SAndroid Build Coastguard Worker     // Run batched inference.
3127*993b0882SAndroid Build Coastguard Worker     const int batch_size = batch_end - batch_start;
3128*993b0882SAndroid Build Coastguard Worker     const int features_size = cached_features.OutputFeaturesSize();
3129*993b0882SAndroid Build Coastguard Worker     TensorView<float> logits = selection_executor_->ComputeLogits(
3130*993b0882SAndroid Build Coastguard Worker         TensorView<float>(all_features.data(), {batch_size, features_size}),
3131*993b0882SAndroid Build Coastguard Worker         selection_interpreter);
3132*993b0882SAndroid Build Coastguard Worker     if (!logits.is_valid()) {
3133*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Couldn't compute logits.";
3134*993b0882SAndroid Build Coastguard Worker       return false;
3135*993b0882SAndroid Build Coastguard Worker     }
3136*993b0882SAndroid Build Coastguard Worker     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3137*993b0882SAndroid Build Coastguard Worker         logits.dim(1) != 1) {
3138*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Mismatching output.";
3139*993b0882SAndroid Build Coastguard Worker       return false;
3140*993b0882SAndroid Build Coastguard Worker     }
3141*993b0882SAndroid Build Coastguard Worker 
3142*993b0882SAndroid Build Coastguard Worker     // Save results.
3143*993b0882SAndroid Build Coastguard Worker     for (int i = batch_start; i < batch_end; ++i) {
3144*993b0882SAndroid Build Coastguard Worker       scored_chunks->push_back(
3145*993b0882SAndroid Build Coastguard Worker           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3146*993b0882SAndroid Build Coastguard Worker     }
3147*993b0882SAndroid Build Coastguard Worker   }
3148*993b0882SAndroid Build Coastguard Worker 
3149*993b0882SAndroid Build Coastguard Worker   return true;
3150*993b0882SAndroid Build Coastguard Worker }
3151*993b0882SAndroid Build Coastguard Worker 
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const3152*993b0882SAndroid Build Coastguard Worker bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3153*993b0882SAndroid Build Coastguard Worker                               int64 reference_time_ms_utc,
3154*993b0882SAndroid Build Coastguard Worker                               const std::string& reference_timezone,
3155*993b0882SAndroid Build Coastguard Worker                               const std::string& locales, ModeFlag mode,
3156*993b0882SAndroid Build Coastguard Worker                               AnnotationUsecase annotation_usecase,
3157*993b0882SAndroid Build Coastguard Worker                               bool is_serialized_entity_data_enabled,
3158*993b0882SAndroid Build Coastguard Worker                               std::vector<AnnotatedSpan>* result) const {
3159*993b0882SAndroid Build Coastguard Worker   if (!datetime_parser_) {
3160*993b0882SAndroid Build Coastguard Worker     return true;
3161*993b0882SAndroid Build Coastguard Worker   }
3162*993b0882SAndroid Build Coastguard Worker   LocaleList locale_list = LocaleList::ParseFrom(locales);
3163*993b0882SAndroid Build Coastguard Worker   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3164*993b0882SAndroid Build Coastguard Worker       datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3165*993b0882SAndroid Build Coastguard Worker                               reference_timezone, locale_list, mode,
3166*993b0882SAndroid Build Coastguard Worker                               annotation_usecase,
3167*993b0882SAndroid Build Coastguard Worker                               /*anchor_start_end=*/false);
3168*993b0882SAndroid Build Coastguard Worker   if (!result_status.ok()) {
3169*993b0882SAndroid Build Coastguard Worker     return false;
3170*993b0882SAndroid Build Coastguard Worker   }
3171*993b0882SAndroid Build Coastguard Worker 
3172*993b0882SAndroid Build Coastguard Worker   for (const DatetimeParseResultSpan& datetime_span :
3173*993b0882SAndroid Build Coastguard Worker        result_status.ValueOrDie()) {
3174*993b0882SAndroid Build Coastguard Worker     AnnotatedSpan annotated_span;
3175*993b0882SAndroid Build Coastguard Worker     annotated_span.span = datetime_span.span;
3176*993b0882SAndroid Build Coastguard Worker     for (const DatetimeParseResult& parse_result : datetime_span.data) {
3177*993b0882SAndroid Build Coastguard Worker       annotated_span.classification.emplace_back(
3178*993b0882SAndroid Build Coastguard Worker           PickCollectionForDatetime(parse_result),
3179*993b0882SAndroid Build Coastguard Worker           datetime_span.target_classification_score,
3180*993b0882SAndroid Build Coastguard Worker           datetime_span.priority_score);
3181*993b0882SAndroid Build Coastguard Worker       annotated_span.classification.back().datetime_parse_result = parse_result;
3182*993b0882SAndroid Build Coastguard Worker       if (is_serialized_entity_data_enabled) {
3183*993b0882SAndroid Build Coastguard Worker         annotated_span.classification.back().serialized_entity_data =
3184*993b0882SAndroid Build Coastguard Worker             CreateDatetimeSerializedEntityData(parse_result);
3185*993b0882SAndroid Build Coastguard Worker       }
3186*993b0882SAndroid Build Coastguard Worker     }
3187*993b0882SAndroid Build Coastguard Worker     annotated_span.source = AnnotatedSpan::Source::DATETIME;
3188*993b0882SAndroid Build Coastguard Worker     result->push_back(std::move(annotated_span));
3189*993b0882SAndroid Build Coastguard Worker   }
3190*993b0882SAndroid Build Coastguard Worker   return true;
3191*993b0882SAndroid Build Coastguard Worker }
3192*993b0882SAndroid Build Coastguard Worker 
model() const3193*993b0882SAndroid Build Coastguard Worker const Model* Annotator::model() const { return model_; }
entity_data_schema() const3194*993b0882SAndroid Build Coastguard Worker const reflection::Schema* Annotator::entity_data_schema() const {
3195*993b0882SAndroid Build Coastguard Worker   return entity_data_schema_;
3196*993b0882SAndroid Build Coastguard Worker }
3197*993b0882SAndroid Build Coastguard Worker 
ViewModel(const void * buffer,int size)3198*993b0882SAndroid Build Coastguard Worker const Model* ViewModel(const void* buffer, int size) {
3199*993b0882SAndroid Build Coastguard Worker   if (!buffer) {
3200*993b0882SAndroid Build Coastguard Worker     return nullptr;
3201*993b0882SAndroid Build Coastguard Worker   }
3202*993b0882SAndroid Build Coastguard Worker 
3203*993b0882SAndroid Build Coastguard Worker   return LoadAndVerifyModel(buffer, size);
3204*993b0882SAndroid Build Coastguard Worker }
3205*993b0882SAndroid Build Coastguard Worker 
LookUpKnowledgeEntity(const std::string & id) const3206*993b0882SAndroid Build Coastguard Worker StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3207*993b0882SAndroid Build Coastguard Worker     const std::string& id) const {
3208*993b0882SAndroid Build Coastguard Worker   if (!knowledge_engine_) {
3209*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::FAILED_PRECONDITION,
3210*993b0882SAndroid Build Coastguard Worker                   "knowledge_engine_ is nullptr");
3211*993b0882SAndroid Build Coastguard Worker   }
3212*993b0882SAndroid Build Coastguard Worker   return knowledge_engine_->LookUpEntity(id);
3213*993b0882SAndroid Build Coastguard Worker }
3214*993b0882SAndroid Build Coastguard Worker 
LookUpKnowledgeEntityProperty(const std::string & mid_str,const std::string & property) const3215*993b0882SAndroid Build Coastguard Worker StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3216*993b0882SAndroid Build Coastguard Worker     const std::string& mid_str, const std::string& property) const {
3217*993b0882SAndroid Build Coastguard Worker   if (!knowledge_engine_) {
3218*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::FAILED_PRECONDITION,
3219*993b0882SAndroid Build Coastguard Worker                   "knowledge_engine_ is nullptr");
3220*993b0882SAndroid Build Coastguard Worker   }
3221*993b0882SAndroid Build Coastguard Worker   return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3222*993b0882SAndroid Build Coastguard Worker }
3223*993b0882SAndroid Build Coastguard Worker 
3224*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
3225