1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/lua-ranker.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <string>
20*993b0882SAndroid Build Coastguard Worker
21*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/mutable.h"
23*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
24*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
25*993b0882SAndroid Build Coastguard Worker
26*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
27*993b0882SAndroid Build Coastguard Worker namespace {
28*993b0882SAndroid Build Coastguard Worker
29*993b0882SAndroid Build Coastguard Worker MATCHER_P2(IsAction, type, response_text, "") {
30*993b0882SAndroid Build Coastguard Worker return testing::Value(arg.type, type) &&
31*993b0882SAndroid Build Coastguard Worker testing::Value(arg.response_text, response_text);
32*993b0882SAndroid Build Coastguard Worker }
33*993b0882SAndroid Build Coastguard Worker
34*993b0882SAndroid Build Coastguard Worker MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
35*993b0882SAndroid Build Coastguard Worker
TestEntitySchema()36*993b0882SAndroid Build Coastguard Worker std::string TestEntitySchema() {
37*993b0882SAndroid Build Coastguard Worker // Create fake entity data schema meta data.
38*993b0882SAndroid Build Coastguard Worker // Cannot use object oriented API here as that is not available for the
39*993b0882SAndroid Build Coastguard Worker // reflection schema.
40*993b0882SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder schema_builder;
41*993b0882SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<reflection::Field>> fields = {
42*993b0882SAndroid Build Coastguard Worker reflection::CreateField(
43*993b0882SAndroid Build Coastguard Worker schema_builder,
44*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("test"),
45*993b0882SAndroid Build Coastguard Worker /*type=*/
46*993b0882SAndroid Build Coastguard Worker reflection::CreateType(schema_builder,
47*993b0882SAndroid Build Coastguard Worker /*base_type=*/reflection::String),
48*993b0882SAndroid Build Coastguard Worker /*id=*/0,
49*993b0882SAndroid Build Coastguard Worker /*offset=*/4)};
50*993b0882SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<reflection::Enum>> enums;
51*993b0882SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<reflection::Object>> objects = {
52*993b0882SAndroid Build Coastguard Worker reflection::CreateObject(
53*993b0882SAndroid Build Coastguard Worker schema_builder,
54*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("EntityData"),
55*993b0882SAndroid Build Coastguard Worker /*fields=*/
56*993b0882SAndroid Build Coastguard Worker schema_builder.CreateVectorOfSortedTables(&fields))};
57*993b0882SAndroid Build Coastguard Worker schema_builder.Finish(reflection::CreateSchema(
58*993b0882SAndroid Build Coastguard Worker schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
59*993b0882SAndroid Build Coastguard Worker schema_builder.CreateVectorOfSortedTables(&enums),
60*993b0882SAndroid Build Coastguard Worker /*(unused) file_ident=*/0,
61*993b0882SAndroid Build Coastguard Worker /*(unused) file_ext=*/0,
62*993b0882SAndroid Build Coastguard Worker /*root_table*/ objects[0]));
63*993b0882SAndroid Build Coastguard Worker return std::string(
64*993b0882SAndroid Build Coastguard Worker reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
65*993b0882SAndroid Build Coastguard Worker schema_builder.GetSize());
66*993b0882SAndroid Build Coastguard Worker }
67*993b0882SAndroid Build Coastguard Worker
// A snippet that returns every action index in order must leave the suggested
// actions unchanged.
TEST(LuaRankingTest, PassThrough) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for i=1,#actions do
      table.insert(result, i)
    end
    return result
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("text_reply"),
                                         IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
93*993b0882SAndroid Build Coastguard Worker
// A snippet that returns an empty table must remove every suggested action.
TEST(LuaRankingTest, Filtering) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    return {}
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
112*993b0882SAndroid Build Coastguard Worker
// A snippet may return the same index several times; the referenced action is
// then expected once per occurrence in the ranked output.
TEST(LuaRankingTest, Duplication) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for i=1,#actions do
      table.insert(result, 1)
    end
    return result
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("text_reply"),
                                         IsActionType("text_reply"),
                                         IsActionType("text_reply")}));
}
138*993b0882SAndroid Build Coastguard Worker
// A snippet that sorts the action indices by ascending score must reorder the
// suggested actions accordingly (lowest score first).
TEST(LuaRankingTest, SortByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    function testScoreSorter(a, b)
      return actions[a].score < actions[b].score
    end
    local result = {}
    for i=1,#actions do
      result[i] = i
    end
    table.sort(result, testScoreSorter)
    return result
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("add_to_collection"),
                                         IsActionType("share_location"),
                                         IsActionType("text_reply")}));
}
168*993b0882SAndroid Build Coastguard Worker
// A snippet that skips every "text_reply" action must leave only the other
// action types in the response.
TEST(LuaRankingTest, SuppressType) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
195*993b0882SAndroid Build Coastguard Worker
// The snippet reads the conversation through `messages`: it only filters out
// "text_reply" actions when the first message text is "hello hello", and
// returns an empty result otherwise.
TEST(LuaRankingTest, HandlesConversation) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    if messages[1].text ~= "hello hello" then
      return result
    end
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, /*entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
225*993b0882SAndroid Build Coastguard Worker
// The snippet reads a field of the action's entity data (`action.test`) and
// keeps only the action whose entity data holds "value_a".
TEST(LuaRankingTest, HandlesEntityData) {
  std::string serialized_schema = TestEntitySchema();
  const reflection::Schema* entity_data_schema =
      flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());

  // Create test entity data: two buffers that differ only in the value of
  // the "test" field.
  MutableFlatbufferBuilder builder(entity_data_schema);
  std::unique_ptr<MutableFlatbuffer> buffer = builder.NewRoot();
  // Stop here if no root buffer was created rather than dereference null.
  ASSERT_TRUE(buffer != nullptr);
  buffer->Set("test", "value_a");
  const std::string serialized_entity_data_a = buffer->Serialize();
  buffer->Set("test", "value_b");
  const std::string serialized_entity_data_b = buffer->Serialize();

  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"", /*type=*/"test",
       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
       /*serialized_entity_data=*/serialized_entity_data_a},
      {/*response_text=*/"", /*type=*/"test",
       /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
       /*serialized_entity_data=*/serialized_entity_data_b},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type == "test" and action.test == "value_a" then
        table.insert(result, id)
      end
    end
    return result
  )";

  const auto ranker = ActionsSuggestionsLuaRanker::Create(
      conversation, test_snippet, entity_data_schema,
      /*annotations_entity_data_schema=*/nullptr, &response);
  // Fail fast if no ranker was produced, instead of dereferencing a null
  // pointer and crashing the test binary.
  ASSERT_TRUE(ranker != nullptr);
  EXPECT_TRUE(ranker->RankActions());
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("test")}));

  // Both candidate actions share the type "test", so the type check above
  // cannot tell them apart; make sure the survivor carries the "value_a"
  // entity data.
  ASSERT_EQ(response.actions.size(), 1u);
  EXPECT_EQ(response.actions[0].serialized_entity_data,
            serialized_entity_data_a);
}
267*993b0882SAndroid Build Coastguard Worker
268*993b0882SAndroid Build Coastguard Worker } // namespace
269*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
270