1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/lua-ranker.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/lua-utils.h"
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
23*993b0882SAndroid Build Coastguard Worker extern "C" {
24*993b0882SAndroid Build Coastguard Worker #endif
25*993b0882SAndroid Build Coastguard Worker #include "lauxlib.h"
26*993b0882SAndroid Build Coastguard Worker #include "lualib.h"
27*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
28*993b0882SAndroid Build Coastguard Worker }
29*993b0882SAndroid Build Coastguard Worker #endif
30*993b0882SAndroid Build Coastguard Worker
31*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
32*993b0882SAndroid Build Coastguard Worker
33*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestionsLuaRanker>
Create(const Conversation & conversation,const std::string & ranker_code,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,ActionsSuggestionsResponse * response)34*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsLuaRanker::Create(
35*993b0882SAndroid Build Coastguard Worker const Conversation& conversation, const std::string& ranker_code,
36*993b0882SAndroid Build Coastguard Worker const reflection::Schema* entity_data_schema,
37*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema,
38*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsResponse* response) {
39*993b0882SAndroid Build Coastguard Worker auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
40*993b0882SAndroid Build Coastguard Worker new ActionsSuggestionsLuaRanker(
41*993b0882SAndroid Build Coastguard Worker conversation, ranker_code, entity_data_schema,
42*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema, response));
43*993b0882SAndroid Build Coastguard Worker if (!ranker->Initialize()) {
44*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
45*993b0882SAndroid Build Coastguard Worker return nullptr;
46*993b0882SAndroid Build Coastguard Worker }
47*993b0882SAndroid Build Coastguard Worker return ranker;
48*993b0882SAndroid Build Coastguard Worker }
49*993b0882SAndroid Build Coastguard Worker
Initialize()50*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestionsLuaRanker::Initialize() {
51*993b0882SAndroid Build Coastguard Worker return RunProtected([this] {
52*993b0882SAndroid Build Coastguard Worker LoadDefaultLibraries();
53*993b0882SAndroid Build Coastguard Worker
54*993b0882SAndroid Build Coastguard Worker // Expose generated actions.
55*993b0882SAndroid Build Coastguard Worker PushActions(&response_->actions, actions_entity_data_schema_,
56*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_);
57*993b0882SAndroid Build Coastguard Worker lua_setglobal(state_, "actions");
58*993b0882SAndroid Build Coastguard Worker
59*993b0882SAndroid Build Coastguard Worker // Expose conversation message stream.
60*993b0882SAndroid Build Coastguard Worker PushConversation(&conversation_.messages,
61*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_);
62*993b0882SAndroid Build Coastguard Worker lua_setglobal(state_, "messages");
63*993b0882SAndroid Build Coastguard Worker return LUA_OK;
64*993b0882SAndroid Build Coastguard Worker }) == LUA_OK;
65*993b0882SAndroid Build Coastguard Worker }
66*993b0882SAndroid Build Coastguard Worker
ReadActionsRanking()67*993b0882SAndroid Build Coastguard Worker int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
68*993b0882SAndroid Build Coastguard Worker if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
69*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Expected actions table, got: "
70*993b0882SAndroid Build Coastguard Worker << lua_type(state_, /*idx=*/-1);
71*993b0882SAndroid Build Coastguard Worker lua_pop(state_, 1);
72*993b0882SAndroid Build Coastguard Worker lua_error(state_);
73*993b0882SAndroid Build Coastguard Worker return LUA_ERRRUN;
74*993b0882SAndroid Build Coastguard Worker }
75*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion> ranked_actions;
76*993b0882SAndroid Build Coastguard Worker lua_pushnil(state_);
77*993b0882SAndroid Build Coastguard Worker while (Next(/*index=*/-2)) {
78*993b0882SAndroid Build Coastguard Worker const int action_id = Read<int>(/*index=*/-1) - 1;
79*993b0882SAndroid Build Coastguard Worker lua_pop(state_, 1);
80*993b0882SAndroid Build Coastguard Worker if (action_id < 0 || action_id >= response_->actions.size()) {
81*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Invalid action index: " << action_id;
82*993b0882SAndroid Build Coastguard Worker lua_error(state_);
83*993b0882SAndroid Build Coastguard Worker return LUA_ERRRUN;
84*993b0882SAndroid Build Coastguard Worker }
85*993b0882SAndroid Build Coastguard Worker ranked_actions.push_back(response_->actions[action_id]);
86*993b0882SAndroid Build Coastguard Worker }
87*993b0882SAndroid Build Coastguard Worker lua_pop(state_, 1);
88*993b0882SAndroid Build Coastguard Worker response_->actions = ranked_actions;
89*993b0882SAndroid Build Coastguard Worker return LUA_OK;
90*993b0882SAndroid Build Coastguard Worker }
91*993b0882SAndroid Build Coastguard Worker
RankActions()92*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestionsLuaRanker::RankActions() {
93*993b0882SAndroid Build Coastguard Worker if (response_->actions.empty()) {
94*993b0882SAndroid Build Coastguard Worker // Nothing to do.
95*993b0882SAndroid Build Coastguard Worker return true;
96*993b0882SAndroid Build Coastguard Worker }
97*993b0882SAndroid Build Coastguard Worker
98*993b0882SAndroid Build Coastguard Worker if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
99*993b0882SAndroid Build Coastguard Worker /*name=*/nullptr) != LUA_OK) {
100*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
101*993b0882SAndroid Build Coastguard Worker return false;
102*993b0882SAndroid Build Coastguard Worker }
103*993b0882SAndroid Build Coastguard Worker
104*993b0882SAndroid Build Coastguard Worker if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
105*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not run ranking snippet.";
106*993b0882SAndroid Build Coastguard Worker return false;
107*993b0882SAndroid Build Coastguard Worker }
108*993b0882SAndroid Build Coastguard Worker
109*993b0882SAndroid Build Coastguard Worker if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
110*993b0882SAndroid Build Coastguard Worker LUA_OK) {
111*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not read lua result.";
112*993b0882SAndroid Build Coastguard Worker return false;
113*993b0882SAndroid Build Coastguard Worker }
114*993b0882SAndroid Build Coastguard Worker return true;
115*993b0882SAndroid Build Coastguard Worker }
116*993b0882SAndroid Build Coastguard Worker
117*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
118