1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/intents/intent-generator.h"
18*993b0882SAndroid Build Coastguard Worker
#include <memory>
#include <string>
#include <utility>
#include <vector>
22*993b0882SAndroid Build Coastguard Worker
23*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/intents/jni-lua.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/java/jni-helper.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib.h"
28*993b0882SAndroid Build Coastguard Worker
29*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
30*993b0882SAndroid Build Coastguard Worker extern "C" {
31*993b0882SAndroid Build Coastguard Worker #endif
32*993b0882SAndroid Build Coastguard Worker #include "lauxlib.h"
33*993b0882SAndroid Build Coastguard Worker #include "lua.h"
34*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
35*993b0882SAndroid Build Coastguard Worker }
36*993b0882SAndroid Build Coastguard Worker #endif
37*993b0882SAndroid Build Coastguard Worker
38*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
39*993b0882SAndroid Build Coastguard Worker namespace {
40*993b0882SAndroid Build Coastguard Worker
// Names of the fields the intent generators expose to the Lua snippets.
//
// NOTE(review): despite "Usec" in its identifier, kReferenceTimeUsecKey names
// the field "reference_time_ms_utc", and it is used to publish the reference
// time in milliseconds. The identifier is kept for compatibility with its users.
constexpr char kReferenceTimeUsecKey[] = "reference_time_ms_utc";
constexpr char kEnableAddContactIntent[] = "enable_add_contact_intent";
constexpr char kEnableSearchIntent[] = "enable_search_intent";
45*993b0882SAndroid Build Coastguard Worker
46*993b0882SAndroid Build Coastguard Worker // Lua environment for classfication result intent generation.
47*993b0882SAndroid Build Coastguard Worker class AnnotatorJniEnvironment : public JniLuaEnvironment {
48*993b0882SAndroid Build Coastguard Worker public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent)49*993b0882SAndroid Build Coastguard Worker AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
50*993b0882SAndroid Build Coastguard Worker const jobject context,
51*993b0882SAndroid Build Coastguard Worker const std::vector<Locale>& device_locales,
52*993b0882SAndroid Build Coastguard Worker const std::string& entity_text,
53*993b0882SAndroid Build Coastguard Worker const ClassificationResult& classification,
54*993b0882SAndroid Build Coastguard Worker const int64 reference_time_ms_utc,
55*993b0882SAndroid Build Coastguard Worker const reflection::Schema* entity_data_schema,
56*993b0882SAndroid Build Coastguard Worker const bool enable_add_contact_intent,
57*993b0882SAndroid Build Coastguard Worker const bool enable_search_intent)
58*993b0882SAndroid Build Coastguard Worker : JniLuaEnvironment(resources, jni_cache, context, device_locales),
59*993b0882SAndroid Build Coastguard Worker entity_text_(entity_text),
60*993b0882SAndroid Build Coastguard Worker classification_(classification),
61*993b0882SAndroid Build Coastguard Worker reference_time_ms_utc_(reference_time_ms_utc),
62*993b0882SAndroid Build Coastguard Worker enable_add_contact_intent_(enable_add_contact_intent),
63*993b0882SAndroid Build Coastguard Worker enable_search_intent_(enable_search_intent),
64*993b0882SAndroid Build Coastguard Worker entity_data_schema_(entity_data_schema) {}
65*993b0882SAndroid Build Coastguard Worker
66*993b0882SAndroid Build Coastguard Worker protected:
SetupExternalHook()67*993b0882SAndroid Build Coastguard Worker void SetupExternalHook() override {
68*993b0882SAndroid Build Coastguard Worker JniLuaEnvironment::SetupExternalHook();
69*993b0882SAndroid Build Coastguard Worker lua_pushinteger(state_, reference_time_ms_utc_);
70*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
71*993b0882SAndroid Build Coastguard Worker
72*993b0882SAndroid Build Coastguard Worker PushAnnotation(classification_, entity_text_, entity_data_schema_);
73*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "entity");
74*993b0882SAndroid Build Coastguard Worker
75*993b0882SAndroid Build Coastguard Worker lua_pushboolean(state_, enable_add_contact_intent_);
76*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, kEnableAddContactIntent);
77*993b0882SAndroid Build Coastguard Worker
78*993b0882SAndroid Build Coastguard Worker lua_pushboolean(state_, enable_search_intent_);
79*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, kEnableSearchIntent);
80*993b0882SAndroid Build Coastguard Worker }
81*993b0882SAndroid Build Coastguard Worker
82*993b0882SAndroid Build Coastguard Worker const std::string& entity_text_;
83*993b0882SAndroid Build Coastguard Worker const ClassificationResult& classification_;
84*993b0882SAndroid Build Coastguard Worker const int64 reference_time_ms_utc_;
85*993b0882SAndroid Build Coastguard Worker const bool enable_add_contact_intent_;
86*993b0882SAndroid Build Coastguard Worker const bool enable_search_intent_;
87*993b0882SAndroid Build Coastguard Worker
88*993b0882SAndroid Build Coastguard Worker // Reflection schema data.
89*993b0882SAndroid Build Coastguard Worker const reflection::Schema* const entity_data_schema_;
90*993b0882SAndroid Build Coastguard Worker };
91*993b0882SAndroid Build Coastguard Worker
92*993b0882SAndroid Build Coastguard Worker // Lua environment for actions intent generation.
93*993b0882SAndroid Build Coastguard Worker class ActionsJniLuaEnvironment : public JniLuaEnvironment {
94*993b0882SAndroid Build Coastguard Worker public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)95*993b0882SAndroid Build Coastguard Worker ActionsJniLuaEnvironment(
96*993b0882SAndroid Build Coastguard Worker const Resources& resources, const JniCache* jni_cache,
97*993b0882SAndroid Build Coastguard Worker const jobject context, const std::vector<Locale>& device_locales,
98*993b0882SAndroid Build Coastguard Worker const Conversation& conversation, const ActionSuggestion& action,
99*993b0882SAndroid Build Coastguard Worker const reflection::Schema* actions_entity_data_schema,
100*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema)
101*993b0882SAndroid Build Coastguard Worker : JniLuaEnvironment(resources, jni_cache, context, device_locales),
102*993b0882SAndroid Build Coastguard Worker conversation_(conversation),
103*993b0882SAndroid Build Coastguard Worker action_(action),
104*993b0882SAndroid Build Coastguard Worker actions_entity_data_schema_(actions_entity_data_schema),
105*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_(annotations_entity_data_schema) {}
106*993b0882SAndroid Build Coastguard Worker
107*993b0882SAndroid Build Coastguard Worker protected:
SetupExternalHook()108*993b0882SAndroid Build Coastguard Worker void SetupExternalHook() override {
109*993b0882SAndroid Build Coastguard Worker JniLuaEnvironment::SetupExternalHook();
110*993b0882SAndroid Build Coastguard Worker PushConversation(&conversation_.messages, annotations_entity_data_schema_);
111*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "conversation");
112*993b0882SAndroid Build Coastguard Worker
113*993b0882SAndroid Build Coastguard Worker PushAction(action_, actions_entity_data_schema_,
114*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_);
115*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "entity");
116*993b0882SAndroid Build Coastguard Worker }
117*993b0882SAndroid Build Coastguard Worker
118*993b0882SAndroid Build Coastguard Worker const Conversation& conversation_;
119*993b0882SAndroid Build Coastguard Worker const ActionSuggestion& action_;
120*993b0882SAndroid Build Coastguard Worker const reflection::Schema* actions_entity_data_schema_;
121*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema_;
122*993b0882SAndroid Build Coastguard Worker };
123*993b0882SAndroid Build Coastguard Worker
124*993b0882SAndroid Build Coastguard Worker } // namespace
125*993b0882SAndroid Build Coastguard Worker
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)126*993b0882SAndroid Build Coastguard Worker std::unique_ptr<IntentGenerator> IntentGenerator::Create(
127*993b0882SAndroid Build Coastguard Worker const IntentFactoryModel* options, const ResourcePool* resources,
128*993b0882SAndroid Build Coastguard Worker const std::shared_ptr<JniCache>& jni_cache) {
129*993b0882SAndroid Build Coastguard Worker std::unique_ptr<IntentGenerator> intent_generator(
130*993b0882SAndroid Build Coastguard Worker new IntentGenerator(options, resources, jni_cache));
131*993b0882SAndroid Build Coastguard Worker
132*993b0882SAndroid Build Coastguard Worker if (options == nullptr || options->generator() == nullptr) {
133*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "No intent generator options.";
134*993b0882SAndroid Build Coastguard Worker return nullptr;
135*993b0882SAndroid Build Coastguard Worker }
136*993b0882SAndroid Build Coastguard Worker
137*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ZlibDecompressor> zlib_decompressor =
138*993b0882SAndroid Build Coastguard Worker ZlibDecompressor::Instance();
139*993b0882SAndroid Build Coastguard Worker if (!zlib_decompressor) {
140*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Cannot initialize decompressor.";
141*993b0882SAndroid Build Coastguard Worker return nullptr;
142*993b0882SAndroid Build Coastguard Worker }
143*993b0882SAndroid Build Coastguard Worker
144*993b0882SAndroid Build Coastguard Worker for (const IntentFactoryModel_::IntentGenerator* generator :
145*993b0882SAndroid Build Coastguard Worker *options->generator()) {
146*993b0882SAndroid Build Coastguard Worker std::string lua_template_generator;
147*993b0882SAndroid Build Coastguard Worker if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
148*993b0882SAndroid Build Coastguard Worker generator->lua_template_generator(),
149*993b0882SAndroid Build Coastguard Worker generator->compressed_lua_template_generator(),
150*993b0882SAndroid Build Coastguard Worker &lua_template_generator)) {
151*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not decompress generator template.";
152*993b0882SAndroid Build Coastguard Worker return nullptr;
153*993b0882SAndroid Build Coastguard Worker }
154*993b0882SAndroid Build Coastguard Worker
155*993b0882SAndroid Build Coastguard Worker std::string lua_code = lua_template_generator;
156*993b0882SAndroid Build Coastguard Worker if (options->precompile_generators()) {
157*993b0882SAndroid Build Coastguard Worker if (!Compile(lua_template_generator, &lua_code)) {
158*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not precompile generator template.";
159*993b0882SAndroid Build Coastguard Worker return nullptr;
160*993b0882SAndroid Build Coastguard Worker }
161*993b0882SAndroid Build Coastguard Worker }
162*993b0882SAndroid Build Coastguard Worker
163*993b0882SAndroid Build Coastguard Worker intent_generator->generators_[generator->type()->str()] = lua_code;
164*993b0882SAndroid Build Coastguard Worker }
165*993b0882SAndroid Build Coastguard Worker
166*993b0882SAndroid Build Coastguard Worker return intent_generator;
167*993b0882SAndroid Build Coastguard Worker }
168*993b0882SAndroid Build Coastguard Worker
ParseDeviceLocales(const jstring device_locales) const169*993b0882SAndroid Build Coastguard Worker std::vector<Locale> IntentGenerator::ParseDeviceLocales(
170*993b0882SAndroid Build Coastguard Worker const jstring device_locales) const {
171*993b0882SAndroid Build Coastguard Worker if (device_locales == nullptr) {
172*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "No locales provided.";
173*993b0882SAndroid Build Coastguard Worker return {};
174*993b0882SAndroid Build Coastguard Worker }
175*993b0882SAndroid Build Coastguard Worker StatusOr<std::string> status_or_locales_str =
176*993b0882SAndroid Build Coastguard Worker JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
177*993b0882SAndroid Build Coastguard Worker if (!status_or_locales_str.ok()) {
178*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR)
179*993b0882SAndroid Build Coastguard Worker << "JStringToUtf8String failed, cannot retrieve provided locales.";
180*993b0882SAndroid Build Coastguard Worker return {};
181*993b0882SAndroid Build Coastguard Worker }
182*993b0882SAndroid Build Coastguard Worker std::vector<Locale> locales;
183*993b0882SAndroid Build Coastguard Worker if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
184*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Cannot parse locales.";
185*993b0882SAndroid Build Coastguard Worker return {};
186*993b0882SAndroid Build Coastguard Worker }
187*993b0882SAndroid Build Coastguard Worker return locales;
188*993b0882SAndroid Build Coastguard Worker }
189*993b0882SAndroid Build Coastguard Worker
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent,std::vector<RemoteActionTemplate> * remote_actions) const190*993b0882SAndroid Build Coastguard Worker bool IntentGenerator::GenerateIntents(
191*993b0882SAndroid Build Coastguard Worker const jstring device_locales, const ClassificationResult& classification,
192*993b0882SAndroid Build Coastguard Worker const int64 reference_time_ms_utc, const std::string& text,
193*993b0882SAndroid Build Coastguard Worker const CodepointSpan selection_indices, const jobject context,
194*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema,
195*993b0882SAndroid Build Coastguard Worker const bool enable_add_contact_intent, const bool enable_search_intent,
196*993b0882SAndroid Build Coastguard Worker std::vector<RemoteActionTemplate>* remote_actions) const {
197*993b0882SAndroid Build Coastguard Worker if (options_ == nullptr) {
198*993b0882SAndroid Build Coastguard Worker return false;
199*993b0882SAndroid Build Coastguard Worker }
200*993b0882SAndroid Build Coastguard Worker
201*993b0882SAndroid Build Coastguard Worker // Retrieve generator for specified entity.
202*993b0882SAndroid Build Coastguard Worker auto it = generators_.find(classification.collection);
203*993b0882SAndroid Build Coastguard Worker if (it == generators_.end()) {
204*993b0882SAndroid Build Coastguard Worker TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
205*993b0882SAndroid Build Coastguard Worker return true;
206*993b0882SAndroid Build Coastguard Worker }
207*993b0882SAndroid Build Coastguard Worker
208*993b0882SAndroid Build Coastguard Worker const std::string entity_text =
209*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(text, /*do_copy=*/false)
210*993b0882SAndroid Build Coastguard Worker .UTF8Substring(selection_indices.first, selection_indices.second);
211*993b0882SAndroid Build Coastguard Worker
212*993b0882SAndroid Build Coastguard Worker std::unique_ptr<AnnotatorJniEnvironment> interpreter(
213*993b0882SAndroid Build Coastguard Worker new AnnotatorJniEnvironment(
214*993b0882SAndroid Build Coastguard Worker resources_, jni_cache_.get(), context,
215*993b0882SAndroid Build Coastguard Worker ParseDeviceLocales(device_locales), entity_text, classification,
216*993b0882SAndroid Build Coastguard Worker reference_time_ms_utc, annotations_entity_data_schema,
217*993b0882SAndroid Build Coastguard Worker enable_add_contact_intent, enable_search_intent));
218*993b0882SAndroid Build Coastguard Worker
219*993b0882SAndroid Build Coastguard Worker if (!interpreter->Initialize()) {
220*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not create Lua interpreter.";
221*993b0882SAndroid Build Coastguard Worker return false;
222*993b0882SAndroid Build Coastguard Worker }
223*993b0882SAndroid Build Coastguard Worker
224*993b0882SAndroid Build Coastguard Worker return interpreter->RunIntentGenerator(it->second, remote_actions);
225*993b0882SAndroid Build Coastguard Worker }
226*993b0882SAndroid Build Coastguard Worker
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const227*993b0882SAndroid Build Coastguard Worker bool IntentGenerator::GenerateIntents(
228*993b0882SAndroid Build Coastguard Worker const jstring device_locales, const ActionSuggestion& action,
229*993b0882SAndroid Build Coastguard Worker const Conversation& conversation, const jobject context,
230*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema,
231*993b0882SAndroid Build Coastguard Worker const reflection::Schema* actions_entity_data_schema,
232*993b0882SAndroid Build Coastguard Worker std::vector<RemoteActionTemplate>* remote_actions) const {
233*993b0882SAndroid Build Coastguard Worker if (options_ == nullptr) {
234*993b0882SAndroid Build Coastguard Worker return false;
235*993b0882SAndroid Build Coastguard Worker }
236*993b0882SAndroid Build Coastguard Worker
237*993b0882SAndroid Build Coastguard Worker // Retrieve generator for specified action.
238*993b0882SAndroid Build Coastguard Worker auto it = generators_.find(action.type);
239*993b0882SAndroid Build Coastguard Worker if (it == generators_.end()) {
240*993b0882SAndroid Build Coastguard Worker return true;
241*993b0882SAndroid Build Coastguard Worker }
242*993b0882SAndroid Build Coastguard Worker
243*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
244*993b0882SAndroid Build Coastguard Worker new ActionsJniLuaEnvironment(
245*993b0882SAndroid Build Coastguard Worker resources_, jni_cache_.get(), context,
246*993b0882SAndroid Build Coastguard Worker ParseDeviceLocales(device_locales), conversation, action,
247*993b0882SAndroid Build Coastguard Worker actions_entity_data_schema, annotations_entity_data_schema));
248*993b0882SAndroid Build Coastguard Worker
249*993b0882SAndroid Build Coastguard Worker if (!interpreter->Initialize()) {
250*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not create Lua interpreter.";
251*993b0882SAndroid Build Coastguard Worker return false;
252*993b0882SAndroid Build Coastguard Worker }
253*993b0882SAndroid Build Coastguard Worker
254*993b0882SAndroid Build Coastguard Worker return interpreter->RunIntentGenerator(it->second, remote_actions);
255*993b0882SAndroid Build Coastguard Worker }
256*993b0882SAndroid Build Coastguard Worker
257*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
258