xref: /aosp_15_r20/external/libtextclassifier/native/actions/lua-actions_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/lua-actions.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <map>
20*993b0882SAndroid Build Coastguard Worker #include <string>
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #include "actions/test-utils.h"
23*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/tflite-model-executor.h"
25*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
26*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
29*993b0882SAndroid Build Coastguard Worker namespace {
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker using testing::ElementsAre;
32*993b0882SAndroid Build Coastguard Worker 
// A snippet that unconditionally returns one action of type "test_action"
// must surface exactly that action. The factory result is checked before use
// so that a snippet/setup failure shows up as a test failure rather than a
// null-pointer dereference that aborts the whole test binary.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  const auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_NE(lua_actions, nullptr);
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
49*993b0882SAndroid Build Coastguard Worker 
// The snippet scans consecutive message pairs and emits a canned reply for
// each pair it recognizes. Only the "hello there!" / "general kenobi!" pair is
// present in the conversation, so exactly one smart reply is expected. The
// factory result is checked before use so a load failure fails the test
// cleanly instead of dereferencing null.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  const auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_NE(lua_actions, nullptr);
  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}
87*993b0882SAndroid Build Coastguard Worker 
88*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, SimpleModelAction) {
89*993b0882SAndroid Build Coastguard Worker   Conversation conversation;
90*993b0882SAndroid Build Coastguard Worker   const std::string test_snippet = R"(
91*993b0882SAndroid Build Coastguard Worker     if #model.actions_scores == 0 then
92*993b0882SAndroid Build Coastguard Worker       return {{ type = "test_action" }}
93*993b0882SAndroid Build Coastguard Worker     end
94*993b0882SAndroid Build Coastguard Worker     return {}
95*993b0882SAndroid Build Coastguard Worker   )";
96*993b0882SAndroid Build Coastguard Worker   std::vector<ActionSuggestion> actions;
97*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
98*993b0882SAndroid Build Coastguard Worker                   test_snippet, conversation,
99*993b0882SAndroid Build Coastguard Worker                   /*model_executor=*/nullptr,
100*993b0882SAndroid Build Coastguard Worker                   /*model_spec=*/nullptr,
101*993b0882SAndroid Build Coastguard Worker                   /*interpreter=*/nullptr,
102*993b0882SAndroid Build Coastguard Worker                   /*actions_entity_data_schema=*/nullptr,
103*993b0882SAndroid Build Coastguard Worker                   /*annotations_entity_data_schema=*/nullptr)
104*993b0882SAndroid Build Coastguard Worker                   ->SuggestActions(&actions));
105*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
106*993b0882SAndroid Build Coastguard Worker }
107*993b0882SAndroid Build Coastguard Worker 
108*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, SimpleModelRepliesAction) {
109*993b0882SAndroid Build Coastguard Worker   Conversation conversation;
110*993b0882SAndroid Build Coastguard Worker   const std::string test_snippet = R"(
111*993b0882SAndroid Build Coastguard Worker     if #model.reply == 0 then
112*993b0882SAndroid Build Coastguard Worker       return {{ type = "test_action" }}
113*993b0882SAndroid Build Coastguard Worker     end
114*993b0882SAndroid Build Coastguard Worker     return {}
115*993b0882SAndroid Build Coastguard Worker   )";
116*993b0882SAndroid Build Coastguard Worker   std::vector<ActionSuggestion> actions;
117*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
118*993b0882SAndroid Build Coastguard Worker                   test_snippet, conversation,
119*993b0882SAndroid Build Coastguard Worker                   /*model_executor=*/nullptr,
120*993b0882SAndroid Build Coastguard Worker                   /*model_spec=*/nullptr,
121*993b0882SAndroid Build Coastguard Worker                   /*interpreter=*/nullptr,
122*993b0882SAndroid Build Coastguard Worker                   /*actions_entity_data_schema=*/nullptr,
123*993b0882SAndroid Build Coastguard Worker                   /*annotations_entity_data_schema=*/nullptr)
124*993b0882SAndroid Build Coastguard Worker                   ->SuggestActions(&actions));
125*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
126*993b0882SAndroid Build Coastguard Worker }
127*993b0882SAndroid Build Coastguard Worker 
128*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, AnnotationActions) {
129*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan annotation;
130*993b0882SAndroid Build Coastguard Worker   annotation.span = {11, 15};
131*993b0882SAndroid Build Coastguard Worker   annotation.classification = {ClassificationResult("address", 1.0)};
132*993b0882SAndroid Build Coastguard Worker   Conversation conversation = {{{/*user_id=*/1, "are you at home?",
133*993b0882SAndroid Build Coastguard Worker                                  /*reference_time_ms_utc=*/0,
134*993b0882SAndroid Build Coastguard Worker                                  /*reference_timezone=*/"Europe/Zurich",
135*993b0882SAndroid Build Coastguard Worker                                  /*annotations=*/{annotation},
136*993b0882SAndroid Build Coastguard Worker                                  /*locales=*/"en"}}};
137*993b0882SAndroid Build Coastguard Worker   const std::string test_snippet = R"(
138*993b0882SAndroid Build Coastguard Worker     local actions = {}
139*993b0882SAndroid Build Coastguard Worker     local last_message = messages[#messages]
140*993b0882SAndroid Build Coastguard Worker     for i, annotation in pairs(last_message.annotation) do
141*993b0882SAndroid Build Coastguard Worker       if #annotation.classification > 0 then
142*993b0882SAndroid Build Coastguard Worker         if annotation.classification[1].collection == "address" then
143*993b0882SAndroid Build Coastguard Worker            local text = string.sub(last_message.text,
144*993b0882SAndroid Build Coastguard Worker                             annotation.span["begin"] + 1,
145*993b0882SAndroid Build Coastguard Worker                             annotation.span["end"])
146*993b0882SAndroid Build Coastguard Worker            table.insert(actions, {
147*993b0882SAndroid Build Coastguard Worker              type = "text_reply",
148*993b0882SAndroid Build Coastguard Worker              response_text = "i am at " .. text,
149*993b0882SAndroid Build Coastguard Worker              annotation = {{
150*993b0882SAndroid Build Coastguard Worker                name = "location",
151*993b0882SAndroid Build Coastguard Worker                span = {
152*993b0882SAndroid Build Coastguard Worker                  text = text
153*993b0882SAndroid Build Coastguard Worker                },
154*993b0882SAndroid Build Coastguard Worker                entity = annotation.classification[1]
155*993b0882SAndroid Build Coastguard Worker              }},
156*993b0882SAndroid Build Coastguard Worker            })
157*993b0882SAndroid Build Coastguard Worker         end
158*993b0882SAndroid Build Coastguard Worker       end
159*993b0882SAndroid Build Coastguard Worker     end
160*993b0882SAndroid Build Coastguard Worker     return actions;
161*993b0882SAndroid Build Coastguard Worker   )";
162*993b0882SAndroid Build Coastguard Worker   std::vector<ActionSuggestion> actions;
163*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
164*993b0882SAndroid Build Coastguard Worker                   test_snippet, conversation,
165*993b0882SAndroid Build Coastguard Worker                   /*model_executor=*/nullptr,
166*993b0882SAndroid Build Coastguard Worker                   /*model_spec=*/nullptr,
167*993b0882SAndroid Build Coastguard Worker                   /*interpreter=*/nullptr,
168*993b0882SAndroid Build Coastguard Worker                   /*actions_entity_data_schema=*/nullptr,
169*993b0882SAndroid Build Coastguard Worker                   /*annotations_entity_data_schema=*/nullptr)
170*993b0882SAndroid Build Coastguard Worker                   ->SuggestActions(&actions));
171*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
172*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
173*993b0882SAndroid Build Coastguard Worker }
174*993b0882SAndroid Build Coastguard Worker 
175*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, EntityData) {
176*993b0882SAndroid Build Coastguard Worker   std::string test_schema = TestEntityDataSchema();
177*993b0882SAndroid Build Coastguard Worker   Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
178*993b0882SAndroid Build Coastguard Worker   const std::string test_snippet = R"(
179*993b0882SAndroid Build Coastguard Worker     return {{
180*993b0882SAndroid Build Coastguard Worker       type = "test",
181*993b0882SAndroid Build Coastguard Worker       entity = {
182*993b0882SAndroid Build Coastguard Worker         greeting = "hello",
183*993b0882SAndroid Build Coastguard Worker         location = "there",
184*993b0882SAndroid Build Coastguard Worker         person = "Kenobi",
185*993b0882SAndroid Build Coastguard Worker       },
186*993b0882SAndroid Build Coastguard Worker     }};
187*993b0882SAndroid Build Coastguard Worker   )";
188*993b0882SAndroid Build Coastguard Worker   std::vector<ActionSuggestion> actions;
189*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
190*993b0882SAndroid Build Coastguard Worker                   test_snippet, conversation,
191*993b0882SAndroid Build Coastguard Worker                   /*model_executor=*/nullptr,
192*993b0882SAndroid Build Coastguard Worker                   /*model_spec=*/nullptr,
193*993b0882SAndroid Build Coastguard Worker                   /*interpreter=*/nullptr,
194*993b0882SAndroid Build Coastguard Worker                   /*actions_entity_data_schema=*/
195*993b0882SAndroid Build Coastguard Worker                   flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
196*993b0882SAndroid Build Coastguard Worker                   /*annotations_entity_data_schema=*/nullptr)
197*993b0882SAndroid Build Coastguard Worker                   ->SuggestActions(&actions));
198*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(actions, testing::SizeIs(1));
199*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ("test", actions.front().type);
200*993b0882SAndroid Build Coastguard Worker   const flatbuffers::Table* entity =
201*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
202*993b0882SAndroid Build Coastguard Worker           actions.front().serialized_entity_data.data()));
203*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
204*993b0882SAndroid Build Coastguard Worker             "hello");
205*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
206*993b0882SAndroid Build Coastguard Worker             "there");
207*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
208*993b0882SAndroid Build Coastguard Worker             "Kenobi");
209*993b0882SAndroid Build Coastguard Worker }
210*993b0882SAndroid Build Coastguard Worker 
211*993b0882SAndroid Build Coastguard Worker }  // namespace
212*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
213