/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/lua-actions.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <map>
20*993b0882SAndroid Build Coastguard Worker #include <string>
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker #include "actions/test-utils.h"
23*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/tflite-model-executor.h"
25*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
26*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
27*993b0882SAndroid Build Coastguard Worker
28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
29*993b0882SAndroid Build Coastguard Worker namespace {
30*993b0882SAndroid Build Coastguard Worker
31*993b0882SAndroid Build Coastguard Worker using testing::ElementsAre;
32*993b0882SAndroid Build Coastguard Worker
// Runs a Lua snippet that unconditionally returns a single action table and
// checks that exactly one suggestion of type "test_action" is produced.
// No model executor, model spec, interpreter or entity-data schemas are
// supplied (all nullptr), and the conversation is left empty.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  std::vector<ActionSuggestion> actions;
  const auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // The factory result is dereferenced right below. Abort with a regular test
  // failure instead of crashing the whole test binary if it is null.
  // NOTE(review): presumably the factory returns null when the snippet fails
  // to initialize -- confirm against actions/lua-actions.h.
  ASSERT_TRUE(lua_actions != nullptr);
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
49*993b0882SAndroid Build Coastguard Worker
// Checks that a Lua snippet can read the conversation through the `messages`
// global. The snippet walks every message that has a successor and emits a
// "text_reply" action when a message/successor pair matches one of two
// hard-coded exchanges. Only the first exchange is present in the conversation
// built here, so exactly one smart reply is expected.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;
  const auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // The factory result is dereferenced right below. Abort with a regular test
  // failure instead of crashing the whole test binary if it is null.
  ASSERT_TRUE(lua_actions != nullptr);
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}
87*993b0882SAndroid Build Coastguard Worker
88*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, SimpleModelAction) {
89*993b0882SAndroid Build Coastguard Worker Conversation conversation;
90*993b0882SAndroid Build Coastguard Worker const std::string test_snippet = R"(
91*993b0882SAndroid Build Coastguard Worker if #model.actions_scores == 0 then
92*993b0882SAndroid Build Coastguard Worker return {{ type = "test_action" }}
93*993b0882SAndroid Build Coastguard Worker end
94*993b0882SAndroid Build Coastguard Worker return {}
95*993b0882SAndroid Build Coastguard Worker )";
96*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion> actions;
97*993b0882SAndroid Build Coastguard Worker EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
98*993b0882SAndroid Build Coastguard Worker test_snippet, conversation,
99*993b0882SAndroid Build Coastguard Worker /*model_executor=*/nullptr,
100*993b0882SAndroid Build Coastguard Worker /*model_spec=*/nullptr,
101*993b0882SAndroid Build Coastguard Worker /*interpreter=*/nullptr,
102*993b0882SAndroid Build Coastguard Worker /*actions_entity_data_schema=*/nullptr,
103*993b0882SAndroid Build Coastguard Worker /*annotations_entity_data_schema=*/nullptr)
104*993b0882SAndroid Build Coastguard Worker ->SuggestActions(&actions));
105*993b0882SAndroid Build Coastguard Worker EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
106*993b0882SAndroid Build Coastguard Worker }
107*993b0882SAndroid Build Coastguard Worker
108*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, SimpleModelRepliesAction) {
109*993b0882SAndroid Build Coastguard Worker Conversation conversation;
110*993b0882SAndroid Build Coastguard Worker const std::string test_snippet = R"(
111*993b0882SAndroid Build Coastguard Worker if #model.reply == 0 then
112*993b0882SAndroid Build Coastguard Worker return {{ type = "test_action" }}
113*993b0882SAndroid Build Coastguard Worker end
114*993b0882SAndroid Build Coastguard Worker return {}
115*993b0882SAndroid Build Coastguard Worker )";
116*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion> actions;
117*993b0882SAndroid Build Coastguard Worker EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
118*993b0882SAndroid Build Coastguard Worker test_snippet, conversation,
119*993b0882SAndroid Build Coastguard Worker /*model_executor=*/nullptr,
120*993b0882SAndroid Build Coastguard Worker /*model_spec=*/nullptr,
121*993b0882SAndroid Build Coastguard Worker /*interpreter=*/nullptr,
122*993b0882SAndroid Build Coastguard Worker /*actions_entity_data_schema=*/nullptr,
123*993b0882SAndroid Build Coastguard Worker /*annotations_entity_data_schema=*/nullptr)
124*993b0882SAndroid Build Coastguard Worker ->SuggestActions(&actions));
125*993b0882SAndroid Build Coastguard Worker EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
126*993b0882SAndroid Build Coastguard Worker }
127*993b0882SAndroid Build Coastguard Worker
128*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, AnnotationActions) {
129*993b0882SAndroid Build Coastguard Worker AnnotatedSpan annotation;
130*993b0882SAndroid Build Coastguard Worker annotation.span = {11, 15};
131*993b0882SAndroid Build Coastguard Worker annotation.classification = {ClassificationResult("address", 1.0)};
132*993b0882SAndroid Build Coastguard Worker Conversation conversation = {{{/*user_id=*/1, "are you at home?",
133*993b0882SAndroid Build Coastguard Worker /*reference_time_ms_utc=*/0,
134*993b0882SAndroid Build Coastguard Worker /*reference_timezone=*/"Europe/Zurich",
135*993b0882SAndroid Build Coastguard Worker /*annotations=*/{annotation},
136*993b0882SAndroid Build Coastguard Worker /*locales=*/"en"}}};
137*993b0882SAndroid Build Coastguard Worker const std::string test_snippet = R"(
138*993b0882SAndroid Build Coastguard Worker local actions = {}
139*993b0882SAndroid Build Coastguard Worker local last_message = messages[#messages]
140*993b0882SAndroid Build Coastguard Worker for i, annotation in pairs(last_message.annotation) do
141*993b0882SAndroid Build Coastguard Worker if #annotation.classification > 0 then
142*993b0882SAndroid Build Coastguard Worker if annotation.classification[1].collection == "address" then
143*993b0882SAndroid Build Coastguard Worker local text = string.sub(last_message.text,
144*993b0882SAndroid Build Coastguard Worker annotation.span["begin"] + 1,
145*993b0882SAndroid Build Coastguard Worker annotation.span["end"])
146*993b0882SAndroid Build Coastguard Worker table.insert(actions, {
147*993b0882SAndroid Build Coastguard Worker type = "text_reply",
148*993b0882SAndroid Build Coastguard Worker response_text = "i am at " .. text,
149*993b0882SAndroid Build Coastguard Worker annotation = {{
150*993b0882SAndroid Build Coastguard Worker name = "location",
151*993b0882SAndroid Build Coastguard Worker span = {
152*993b0882SAndroid Build Coastguard Worker text = text
153*993b0882SAndroid Build Coastguard Worker },
154*993b0882SAndroid Build Coastguard Worker entity = annotation.classification[1]
155*993b0882SAndroid Build Coastguard Worker }},
156*993b0882SAndroid Build Coastguard Worker })
157*993b0882SAndroid Build Coastguard Worker end
158*993b0882SAndroid Build Coastguard Worker end
159*993b0882SAndroid Build Coastguard Worker end
160*993b0882SAndroid Build Coastguard Worker return actions;
161*993b0882SAndroid Build Coastguard Worker )";
162*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion> actions;
163*993b0882SAndroid Build Coastguard Worker EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
164*993b0882SAndroid Build Coastguard Worker test_snippet, conversation,
165*993b0882SAndroid Build Coastguard Worker /*model_executor=*/nullptr,
166*993b0882SAndroid Build Coastguard Worker /*model_spec=*/nullptr,
167*993b0882SAndroid Build Coastguard Worker /*interpreter=*/nullptr,
168*993b0882SAndroid Build Coastguard Worker /*actions_entity_data_schema=*/nullptr,
169*993b0882SAndroid Build Coastguard Worker /*annotations_entity_data_schema=*/nullptr)
170*993b0882SAndroid Build Coastguard Worker ->SuggestActions(&actions));
171*993b0882SAndroid Build Coastguard Worker EXPECT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
172*993b0882SAndroid Build Coastguard Worker EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
173*993b0882SAndroid Build Coastguard Worker }
174*993b0882SAndroid Build Coastguard Worker
175*993b0882SAndroid Build Coastguard Worker TEST(LuaActions, EntityData) {
176*993b0882SAndroid Build Coastguard Worker std::string test_schema = TestEntityDataSchema();
177*993b0882SAndroid Build Coastguard Worker Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
178*993b0882SAndroid Build Coastguard Worker const std::string test_snippet = R"(
179*993b0882SAndroid Build Coastguard Worker return {{
180*993b0882SAndroid Build Coastguard Worker type = "test",
181*993b0882SAndroid Build Coastguard Worker entity = {
182*993b0882SAndroid Build Coastguard Worker greeting = "hello",
183*993b0882SAndroid Build Coastguard Worker location = "there",
184*993b0882SAndroid Build Coastguard Worker person = "Kenobi",
185*993b0882SAndroid Build Coastguard Worker },
186*993b0882SAndroid Build Coastguard Worker }};
187*993b0882SAndroid Build Coastguard Worker )";
188*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion> actions;
189*993b0882SAndroid Build Coastguard Worker EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
190*993b0882SAndroid Build Coastguard Worker test_snippet, conversation,
191*993b0882SAndroid Build Coastguard Worker /*model_executor=*/nullptr,
192*993b0882SAndroid Build Coastguard Worker /*model_spec=*/nullptr,
193*993b0882SAndroid Build Coastguard Worker /*interpreter=*/nullptr,
194*993b0882SAndroid Build Coastguard Worker /*actions_entity_data_schema=*/
195*993b0882SAndroid Build Coastguard Worker flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
196*993b0882SAndroid Build Coastguard Worker /*annotations_entity_data_schema=*/nullptr)
197*993b0882SAndroid Build Coastguard Worker ->SuggestActions(&actions));
198*993b0882SAndroid Build Coastguard Worker EXPECT_THAT(actions, testing::SizeIs(1));
199*993b0882SAndroid Build Coastguard Worker EXPECT_EQ("test", actions.front().type);
200*993b0882SAndroid Build Coastguard Worker const flatbuffers::Table* entity =
201*993b0882SAndroid Build Coastguard Worker flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
202*993b0882SAndroid Build Coastguard Worker actions.front().serialized_entity_data.data()));
203*993b0882SAndroid Build Coastguard Worker EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
204*993b0882SAndroid Build Coastguard Worker "hello");
205*993b0882SAndroid Build Coastguard Worker EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
206*993b0882SAndroid Build Coastguard Worker "there");
207*993b0882SAndroid Build Coastguard Worker EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
208*993b0882SAndroid Build Coastguard Worker "Kenobi");
209*993b0882SAndroid Build Coastguard Worker }
210*993b0882SAndroid Build Coastguard Worker
211*993b0882SAndroid Build Coastguard Worker } // namespace
212*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
213