// Source: external/libtextclassifier/native/actions/lua-ranker.cc (aosp_15_r20, revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
#include "actions/lua-ranker.h"

#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
32*993b0882SAndroid Build Coastguard Worker 
33*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestionsLuaRanker>
Create(const Conversation & conversation,const std::string & ranker_code,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,ActionsSuggestionsResponse * response)34*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsLuaRanker::Create(
35*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const std::string& ranker_code,
36*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* entity_data_schema,
37*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotations_entity_data_schema,
38*993b0882SAndroid Build Coastguard Worker     ActionsSuggestionsResponse* response) {
39*993b0882SAndroid Build Coastguard Worker   auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
40*993b0882SAndroid Build Coastguard Worker       new ActionsSuggestionsLuaRanker(
41*993b0882SAndroid Build Coastguard Worker           conversation, ranker_code, entity_data_schema,
42*993b0882SAndroid Build Coastguard Worker           annotations_entity_data_schema, response));
43*993b0882SAndroid Build Coastguard Worker   if (!ranker->Initialize()) {
44*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
45*993b0882SAndroid Build Coastguard Worker     return nullptr;
46*993b0882SAndroid Build Coastguard Worker   }
47*993b0882SAndroid Build Coastguard Worker   return ranker;
48*993b0882SAndroid Build Coastguard Worker }
49*993b0882SAndroid Build Coastguard Worker 
Initialize()50*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestionsLuaRanker::Initialize() {
51*993b0882SAndroid Build Coastguard Worker   return RunProtected([this] {
52*993b0882SAndroid Build Coastguard Worker            LoadDefaultLibraries();
53*993b0882SAndroid Build Coastguard Worker 
54*993b0882SAndroid Build Coastguard Worker            // Expose generated actions.
55*993b0882SAndroid Build Coastguard Worker            PushActions(&response_->actions, actions_entity_data_schema_,
56*993b0882SAndroid Build Coastguard Worker                        annotations_entity_data_schema_);
57*993b0882SAndroid Build Coastguard Worker            lua_setglobal(state_, "actions");
58*993b0882SAndroid Build Coastguard Worker 
59*993b0882SAndroid Build Coastguard Worker            // Expose conversation message stream.
60*993b0882SAndroid Build Coastguard Worker            PushConversation(&conversation_.messages,
61*993b0882SAndroid Build Coastguard Worker                             annotations_entity_data_schema_);
62*993b0882SAndroid Build Coastguard Worker            lua_setglobal(state_, "messages");
63*993b0882SAndroid Build Coastguard Worker            return LUA_OK;
64*993b0882SAndroid Build Coastguard Worker          }) == LUA_OK;
65*993b0882SAndroid Build Coastguard Worker }
66*993b0882SAndroid Build Coastguard Worker 
ReadActionsRanking()67*993b0882SAndroid Build Coastguard Worker int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
68*993b0882SAndroid Build Coastguard Worker   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
69*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Expected actions table, got: "
70*993b0882SAndroid Build Coastguard Worker                    << lua_type(state_, /*idx=*/-1);
71*993b0882SAndroid Build Coastguard Worker     lua_pop(state_, 1);
72*993b0882SAndroid Build Coastguard Worker     lua_error(state_);
73*993b0882SAndroid Build Coastguard Worker     return LUA_ERRRUN;
74*993b0882SAndroid Build Coastguard Worker   }
75*993b0882SAndroid Build Coastguard Worker   std::vector<ActionSuggestion> ranked_actions;
76*993b0882SAndroid Build Coastguard Worker   lua_pushnil(state_);
77*993b0882SAndroid Build Coastguard Worker   while (Next(/*index=*/-2)) {
78*993b0882SAndroid Build Coastguard Worker     const int action_id = Read<int>(/*index=*/-1) - 1;
79*993b0882SAndroid Build Coastguard Worker     lua_pop(state_, 1);
80*993b0882SAndroid Build Coastguard Worker     if (action_id < 0 || action_id >= response_->actions.size()) {
81*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Invalid action index: " << action_id;
82*993b0882SAndroid Build Coastguard Worker       lua_error(state_);
83*993b0882SAndroid Build Coastguard Worker       return LUA_ERRRUN;
84*993b0882SAndroid Build Coastguard Worker     }
85*993b0882SAndroid Build Coastguard Worker     ranked_actions.push_back(response_->actions[action_id]);
86*993b0882SAndroid Build Coastguard Worker   }
87*993b0882SAndroid Build Coastguard Worker   lua_pop(state_, 1);
88*993b0882SAndroid Build Coastguard Worker   response_->actions = ranked_actions;
89*993b0882SAndroid Build Coastguard Worker   return LUA_OK;
90*993b0882SAndroid Build Coastguard Worker }
91*993b0882SAndroid Build Coastguard Worker 
RankActions()92*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestionsLuaRanker::RankActions() {
93*993b0882SAndroid Build Coastguard Worker   if (response_->actions.empty()) {
94*993b0882SAndroid Build Coastguard Worker     // Nothing to do.
95*993b0882SAndroid Build Coastguard Worker     return true;
96*993b0882SAndroid Build Coastguard Worker   }
97*993b0882SAndroid Build Coastguard Worker 
98*993b0882SAndroid Build Coastguard Worker   if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
99*993b0882SAndroid Build Coastguard Worker                       /*name=*/nullptr) != LUA_OK) {
100*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
101*993b0882SAndroid Build Coastguard Worker     return false;
102*993b0882SAndroid Build Coastguard Worker   }
103*993b0882SAndroid Build Coastguard Worker 
104*993b0882SAndroid Build Coastguard Worker   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
105*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not run ranking snippet.";
106*993b0882SAndroid Build Coastguard Worker     return false;
107*993b0882SAndroid Build Coastguard Worker   }
108*993b0882SAndroid Build Coastguard Worker 
109*993b0882SAndroid Build Coastguard Worker   if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
110*993b0882SAndroid Build Coastguard Worker       LUA_OK) {
111*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not read lua result.";
112*993b0882SAndroid Build Coastguard Worker     return false;
113*993b0882SAndroid Build Coastguard Worker   }
114*993b0882SAndroid Build Coastguard Worker   return true;
115*993b0882SAndroid Build Coastguard Worker }
116*993b0882SAndroid Build Coastguard Worker 
117*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
118