xref: /aosp_15_r20/external/libtextclassifier/native/utils/intents/intent-generator.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/intents/intent-generator.h"
18*993b0882SAndroid Build Coastguard Worker 
#include <memory>
#include <string>
#include <utility>
#include <vector>
22*993b0882SAndroid Build Coastguard Worker 
23*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/intents/jni-lua.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/java/jni-helper.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib.h"
28*993b0882SAndroid Build Coastguard Worker 
29*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
30*993b0882SAndroid Build Coastguard Worker extern "C" {
31*993b0882SAndroid Build Coastguard Worker #endif
32*993b0882SAndroid Build Coastguard Worker #include "lauxlib.h"
33*993b0882SAndroid Build Coastguard Worker #include "lua.h"
34*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
35*993b0882SAndroid Build Coastguard Worker }
36*993b0882SAndroid Build Coastguard Worker #endif
37*993b0882SAndroid Build Coastguard Worker 
38*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
39*993b0882SAndroid Build Coastguard Worker namespace {
40*993b0882SAndroid Build Coastguard Worker 
// Names of the fields that AnnotatorJniEnvironment::SetupExternalHook() stores
// into the Lua table prepared by JniLuaEnvironment::SetupExternalHook().
// (The enclosing anonymous namespace already gives these internal linkage.)
//
// NOTE(review): despite the "Usec" in its identifier, kReferenceTimeUsecKey
// names the "reference_time_ms_utc" field, which receives the value of the
// `reference_time_ms_utc` constructor argument (milliseconds, per its name).
// Consider renaming the constant to kReferenceTimeMsUtcKey.
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
constexpr const char* kEnableAddContactIntent = "enable_add_contact_intent";
constexpr const char* kEnableSearchIntent = "enable_search_intent";
45*993b0882SAndroid Build Coastguard Worker 
46*993b0882SAndroid Build Coastguard Worker // Lua environment for classfication result intent generation.
47*993b0882SAndroid Build Coastguard Worker class AnnotatorJniEnvironment : public JniLuaEnvironment {
48*993b0882SAndroid Build Coastguard Worker  public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent)49*993b0882SAndroid Build Coastguard Worker   AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
50*993b0882SAndroid Build Coastguard Worker                           const jobject context,
51*993b0882SAndroid Build Coastguard Worker                           const std::vector<Locale>& device_locales,
52*993b0882SAndroid Build Coastguard Worker                           const std::string& entity_text,
53*993b0882SAndroid Build Coastguard Worker                           const ClassificationResult& classification,
54*993b0882SAndroid Build Coastguard Worker                           const int64 reference_time_ms_utc,
55*993b0882SAndroid Build Coastguard Worker                           const reflection::Schema* entity_data_schema,
56*993b0882SAndroid Build Coastguard Worker                           const bool enable_add_contact_intent,
57*993b0882SAndroid Build Coastguard Worker                           const bool enable_search_intent)
58*993b0882SAndroid Build Coastguard Worker       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
59*993b0882SAndroid Build Coastguard Worker         entity_text_(entity_text),
60*993b0882SAndroid Build Coastguard Worker         classification_(classification),
61*993b0882SAndroid Build Coastguard Worker         reference_time_ms_utc_(reference_time_ms_utc),
62*993b0882SAndroid Build Coastguard Worker         enable_add_contact_intent_(enable_add_contact_intent),
63*993b0882SAndroid Build Coastguard Worker         enable_search_intent_(enable_search_intent),
64*993b0882SAndroid Build Coastguard Worker         entity_data_schema_(entity_data_schema) {}
65*993b0882SAndroid Build Coastguard Worker 
66*993b0882SAndroid Build Coastguard Worker  protected:
SetupExternalHook()67*993b0882SAndroid Build Coastguard Worker   void SetupExternalHook() override {
68*993b0882SAndroid Build Coastguard Worker     JniLuaEnvironment::SetupExternalHook();
69*993b0882SAndroid Build Coastguard Worker     lua_pushinteger(state_, reference_time_ms_utc_);
70*993b0882SAndroid Build Coastguard Worker     lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
71*993b0882SAndroid Build Coastguard Worker 
72*993b0882SAndroid Build Coastguard Worker     PushAnnotation(classification_, entity_text_, entity_data_schema_);
73*993b0882SAndroid Build Coastguard Worker     lua_setfield(state_, /*idx=*/-2, "entity");
74*993b0882SAndroid Build Coastguard Worker 
75*993b0882SAndroid Build Coastguard Worker     lua_pushboolean(state_, enable_add_contact_intent_);
76*993b0882SAndroid Build Coastguard Worker     lua_setfield(state_, /*idx=*/-2, kEnableAddContactIntent);
77*993b0882SAndroid Build Coastguard Worker 
78*993b0882SAndroid Build Coastguard Worker     lua_pushboolean(state_, enable_search_intent_);
79*993b0882SAndroid Build Coastguard Worker     lua_setfield(state_, /*idx=*/-2, kEnableSearchIntent);
80*993b0882SAndroid Build Coastguard Worker   }
81*993b0882SAndroid Build Coastguard Worker 
82*993b0882SAndroid Build Coastguard Worker   const std::string& entity_text_;
83*993b0882SAndroid Build Coastguard Worker   const ClassificationResult& classification_;
84*993b0882SAndroid Build Coastguard Worker   const int64 reference_time_ms_utc_;
85*993b0882SAndroid Build Coastguard Worker   const bool enable_add_contact_intent_;
86*993b0882SAndroid Build Coastguard Worker   const bool enable_search_intent_;
87*993b0882SAndroid Build Coastguard Worker 
88*993b0882SAndroid Build Coastguard Worker   // Reflection schema data.
89*993b0882SAndroid Build Coastguard Worker   const reflection::Schema* const entity_data_schema_;
90*993b0882SAndroid Build Coastguard Worker };
91*993b0882SAndroid Build Coastguard Worker 
92*993b0882SAndroid Build Coastguard Worker // Lua environment for actions intent generation.
93*993b0882SAndroid Build Coastguard Worker class ActionsJniLuaEnvironment : public JniLuaEnvironment {
94*993b0882SAndroid Build Coastguard Worker  public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)95*993b0882SAndroid Build Coastguard Worker   ActionsJniLuaEnvironment(
96*993b0882SAndroid Build Coastguard Worker       const Resources& resources, const JniCache* jni_cache,
97*993b0882SAndroid Build Coastguard Worker       const jobject context, const std::vector<Locale>& device_locales,
98*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation, const ActionSuggestion& action,
99*993b0882SAndroid Build Coastguard Worker       const reflection::Schema* actions_entity_data_schema,
100*993b0882SAndroid Build Coastguard Worker       const reflection::Schema* annotations_entity_data_schema)
101*993b0882SAndroid Build Coastguard Worker       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
102*993b0882SAndroid Build Coastguard Worker         conversation_(conversation),
103*993b0882SAndroid Build Coastguard Worker         action_(action),
104*993b0882SAndroid Build Coastguard Worker         actions_entity_data_schema_(actions_entity_data_schema),
105*993b0882SAndroid Build Coastguard Worker         annotations_entity_data_schema_(annotations_entity_data_schema) {}
106*993b0882SAndroid Build Coastguard Worker 
107*993b0882SAndroid Build Coastguard Worker  protected:
SetupExternalHook()108*993b0882SAndroid Build Coastguard Worker   void SetupExternalHook() override {
109*993b0882SAndroid Build Coastguard Worker     JniLuaEnvironment::SetupExternalHook();
110*993b0882SAndroid Build Coastguard Worker     PushConversation(&conversation_.messages, annotations_entity_data_schema_);
111*993b0882SAndroid Build Coastguard Worker     lua_setfield(state_, /*idx=*/-2, "conversation");
112*993b0882SAndroid Build Coastguard Worker 
113*993b0882SAndroid Build Coastguard Worker     PushAction(action_, actions_entity_data_schema_,
114*993b0882SAndroid Build Coastguard Worker                annotations_entity_data_schema_);
115*993b0882SAndroid Build Coastguard Worker     lua_setfield(state_, /*idx=*/-2, "entity");
116*993b0882SAndroid Build Coastguard Worker   }
117*993b0882SAndroid Build Coastguard Worker 
118*993b0882SAndroid Build Coastguard Worker   const Conversation& conversation_;
119*993b0882SAndroid Build Coastguard Worker   const ActionSuggestion& action_;
120*993b0882SAndroid Build Coastguard Worker   const reflection::Schema* actions_entity_data_schema_;
121*993b0882SAndroid Build Coastguard Worker   const reflection::Schema* annotations_entity_data_schema_;
122*993b0882SAndroid Build Coastguard Worker };
123*993b0882SAndroid Build Coastguard Worker 
124*993b0882SAndroid Build Coastguard Worker }  // namespace
125*993b0882SAndroid Build Coastguard Worker 
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)126*993b0882SAndroid Build Coastguard Worker std::unique_ptr<IntentGenerator> IntentGenerator::Create(
127*993b0882SAndroid Build Coastguard Worker     const IntentFactoryModel* options, const ResourcePool* resources,
128*993b0882SAndroid Build Coastguard Worker     const std::shared_ptr<JniCache>& jni_cache) {
129*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<IntentGenerator> intent_generator(
130*993b0882SAndroid Build Coastguard Worker       new IntentGenerator(options, resources, jni_cache));
131*993b0882SAndroid Build Coastguard Worker 
132*993b0882SAndroid Build Coastguard Worker   if (options == nullptr || options->generator() == nullptr) {
133*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No intent generator options.";
134*993b0882SAndroid Build Coastguard Worker     return nullptr;
135*993b0882SAndroid Build Coastguard Worker   }
136*993b0882SAndroid Build Coastguard Worker 
137*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
138*993b0882SAndroid Build Coastguard Worker       ZlibDecompressor::Instance();
139*993b0882SAndroid Build Coastguard Worker   if (!zlib_decompressor) {
140*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
141*993b0882SAndroid Build Coastguard Worker     return nullptr;
142*993b0882SAndroid Build Coastguard Worker   }
143*993b0882SAndroid Build Coastguard Worker 
144*993b0882SAndroid Build Coastguard Worker   for (const IntentFactoryModel_::IntentGenerator* generator :
145*993b0882SAndroid Build Coastguard Worker        *options->generator()) {
146*993b0882SAndroid Build Coastguard Worker     std::string lua_template_generator;
147*993b0882SAndroid Build Coastguard Worker     if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
148*993b0882SAndroid Build Coastguard Worker             generator->lua_template_generator(),
149*993b0882SAndroid Build Coastguard Worker             generator->compressed_lua_template_generator(),
150*993b0882SAndroid Build Coastguard Worker             &lua_template_generator)) {
151*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not decompress generator template.";
152*993b0882SAndroid Build Coastguard Worker       return nullptr;
153*993b0882SAndroid Build Coastguard Worker     }
154*993b0882SAndroid Build Coastguard Worker 
155*993b0882SAndroid Build Coastguard Worker     std::string lua_code = lua_template_generator;
156*993b0882SAndroid Build Coastguard Worker     if (options->precompile_generators()) {
157*993b0882SAndroid Build Coastguard Worker       if (!Compile(lua_template_generator, &lua_code)) {
158*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not precompile generator template.";
159*993b0882SAndroid Build Coastguard Worker         return nullptr;
160*993b0882SAndroid Build Coastguard Worker       }
161*993b0882SAndroid Build Coastguard Worker     }
162*993b0882SAndroid Build Coastguard Worker 
163*993b0882SAndroid Build Coastguard Worker     intent_generator->generators_[generator->type()->str()] = lua_code;
164*993b0882SAndroid Build Coastguard Worker   }
165*993b0882SAndroid Build Coastguard Worker 
166*993b0882SAndroid Build Coastguard Worker   return intent_generator;
167*993b0882SAndroid Build Coastguard Worker }
168*993b0882SAndroid Build Coastguard Worker 
ParseDeviceLocales(const jstring device_locales) const169*993b0882SAndroid Build Coastguard Worker std::vector<Locale> IntentGenerator::ParseDeviceLocales(
170*993b0882SAndroid Build Coastguard Worker     const jstring device_locales) const {
171*993b0882SAndroid Build Coastguard Worker   if (device_locales == nullptr) {
172*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No locales provided.";
173*993b0882SAndroid Build Coastguard Worker     return {};
174*993b0882SAndroid Build Coastguard Worker   }
175*993b0882SAndroid Build Coastguard Worker   StatusOr<std::string> status_or_locales_str =
176*993b0882SAndroid Build Coastguard Worker       JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
177*993b0882SAndroid Build Coastguard Worker   if (!status_or_locales_str.ok()) {
178*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR)
179*993b0882SAndroid Build Coastguard Worker         << "JStringToUtf8String failed, cannot retrieve provided locales.";
180*993b0882SAndroid Build Coastguard Worker     return {};
181*993b0882SAndroid Build Coastguard Worker   }
182*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> locales;
183*993b0882SAndroid Build Coastguard Worker   if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
184*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Cannot parse locales.";
185*993b0882SAndroid Build Coastguard Worker     return {};
186*993b0882SAndroid Build Coastguard Worker   }
187*993b0882SAndroid Build Coastguard Worker   return locales;
188*993b0882SAndroid Build Coastguard Worker }
189*993b0882SAndroid Build Coastguard Worker 
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent,std::vector<RemoteActionTemplate> * remote_actions) const190*993b0882SAndroid Build Coastguard Worker bool IntentGenerator::GenerateIntents(
191*993b0882SAndroid Build Coastguard Worker     const jstring device_locales, const ClassificationResult& classification,
192*993b0882SAndroid Build Coastguard Worker     const int64 reference_time_ms_utc, const std::string& text,
193*993b0882SAndroid Build Coastguard Worker     const CodepointSpan selection_indices, const jobject context,
194*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotations_entity_data_schema,
195*993b0882SAndroid Build Coastguard Worker     const bool enable_add_contact_intent, const bool enable_search_intent,
196*993b0882SAndroid Build Coastguard Worker     std::vector<RemoteActionTemplate>* remote_actions) const {
197*993b0882SAndroid Build Coastguard Worker   if (options_ == nullptr) {
198*993b0882SAndroid Build Coastguard Worker     return false;
199*993b0882SAndroid Build Coastguard Worker   }
200*993b0882SAndroid Build Coastguard Worker 
201*993b0882SAndroid Build Coastguard Worker   // Retrieve generator for specified entity.
202*993b0882SAndroid Build Coastguard Worker   auto it = generators_.find(classification.collection);
203*993b0882SAndroid Build Coastguard Worker   if (it == generators_.end()) {
204*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
205*993b0882SAndroid Build Coastguard Worker     return true;
206*993b0882SAndroid Build Coastguard Worker   }
207*993b0882SAndroid Build Coastguard Worker 
208*993b0882SAndroid Build Coastguard Worker   const std::string entity_text =
209*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(text, /*do_copy=*/false)
210*993b0882SAndroid Build Coastguard Worker           .UTF8Substring(selection_indices.first, selection_indices.second);
211*993b0882SAndroid Build Coastguard Worker 
212*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<AnnotatorJniEnvironment> interpreter(
213*993b0882SAndroid Build Coastguard Worker       new AnnotatorJniEnvironment(
214*993b0882SAndroid Build Coastguard Worker           resources_, jni_cache_.get(), context,
215*993b0882SAndroid Build Coastguard Worker           ParseDeviceLocales(device_locales), entity_text, classification,
216*993b0882SAndroid Build Coastguard Worker           reference_time_ms_utc, annotations_entity_data_schema,
217*993b0882SAndroid Build Coastguard Worker           enable_add_contact_intent, enable_search_intent));
218*993b0882SAndroid Build Coastguard Worker 
219*993b0882SAndroid Build Coastguard Worker   if (!interpreter->Initialize()) {
220*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
221*993b0882SAndroid Build Coastguard Worker     return false;
222*993b0882SAndroid Build Coastguard Worker   }
223*993b0882SAndroid Build Coastguard Worker 
224*993b0882SAndroid Build Coastguard Worker   return interpreter->RunIntentGenerator(it->second, remote_actions);
225*993b0882SAndroid Build Coastguard Worker }
226*993b0882SAndroid Build Coastguard Worker 
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const227*993b0882SAndroid Build Coastguard Worker bool IntentGenerator::GenerateIntents(
228*993b0882SAndroid Build Coastguard Worker     const jstring device_locales, const ActionSuggestion& action,
229*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const jobject context,
230*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotations_entity_data_schema,
231*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* actions_entity_data_schema,
232*993b0882SAndroid Build Coastguard Worker     std::vector<RemoteActionTemplate>* remote_actions) const {
233*993b0882SAndroid Build Coastguard Worker   if (options_ == nullptr) {
234*993b0882SAndroid Build Coastguard Worker     return false;
235*993b0882SAndroid Build Coastguard Worker   }
236*993b0882SAndroid Build Coastguard Worker 
237*993b0882SAndroid Build Coastguard Worker   // Retrieve generator for specified action.
238*993b0882SAndroid Build Coastguard Worker   auto it = generators_.find(action.type);
239*993b0882SAndroid Build Coastguard Worker   if (it == generators_.end()) {
240*993b0882SAndroid Build Coastguard Worker     return true;
241*993b0882SAndroid Build Coastguard Worker   }
242*993b0882SAndroid Build Coastguard Worker 
243*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
244*993b0882SAndroid Build Coastguard Worker       new ActionsJniLuaEnvironment(
245*993b0882SAndroid Build Coastguard Worker           resources_, jni_cache_.get(), context,
246*993b0882SAndroid Build Coastguard Worker           ParseDeviceLocales(device_locales), conversation, action,
247*993b0882SAndroid Build Coastguard Worker           actions_entity_data_schema, annotations_entity_data_schema));
248*993b0882SAndroid Build Coastguard Worker 
249*993b0882SAndroid Build Coastguard Worker   if (!interpreter->Initialize()) {
250*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
251*993b0882SAndroid Build Coastguard Worker     return false;
252*993b0882SAndroid Build Coastguard Worker   }
253*993b0882SAndroid Build Coastguard Worker 
254*993b0882SAndroid Build Coastguard Worker   return interpreter->RunIntentGenerator(it->second, remote_actions);
255*993b0882SAndroid Build Coastguard Worker }
256*993b0882SAndroid Build Coastguard Worker 
257*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
258