/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "actions/ranker.h"
18
19 #include <string>
20
21 #include "actions/actions_model_generated.h"
22 #include "actions/types.h"
23 #include "utils/zlib/zlib.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26
27 namespace libtextclassifier3 {
28 namespace {
29
30 MATCHER_P3(IsAction, type, response_text, score, "") {
31 return testing::Value(arg.type, type) &&
32 testing::Value(arg.response_text, response_text) &&
33 testing::Value(arg.score, score);
34 }
35
36 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
37
// Two identical smart replies are deduplicated into one; the copy with the
// higher score (1.0) is the one expected to remain.
TEST(RankingTest, DeduplicationSmartReply) {
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &options));
  const RankingOptions* ranking_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      ranking_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");

  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/1.0});
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/0.5});

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray(
                  {IsAction("text_reply", "hello there", 1.0)}));
}
59
// Replies that differ only in score are deduplicated, but a reply carrying
// serialized entity data is treated as distinct and kept alongside them.
TEST(RankingTest, DeduplicationExtraData) {
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &options));
  const RankingOptions* ranking_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      ranking_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");

  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/1.0,
                              /*priority_score=*/0.0});
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/0.5,
                              /*priority_score=*/0.0});
  // Same text and type as the two above, but with entity data attached.
  response.actions.push_back({/*response_text=*/"hello there",
                              /*type=*/"text_reply", /*score=*/0.6,
                              /*priority_score=*/0.0,
                              /*annotations=*/{},
                              /*serialized_entity_data=*/"test"});

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
                                 // Is kept as it has different entity data.
                                 IsAction("text_reply", "hello there", 0.6)}));
}
88
// Two `view_map` actions annotate the same address span; only the one with
// score 1.0 is expected to survive. The `call_phone` action annotates a
// different span and is unaffected.
TEST(RankingTest, DeduplicationAnnotations) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};

  ActionSuggestionAnnotation low_score_address;
  low_score_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                            /*text=*/"742 Evergreen Terrace"};
  low_score_address.entity = ClassificationResult("address", 0.5);

  ActionSuggestionAnnotation high_score_address;
  high_score_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                             /*text=*/"742 Evergreen Terrace"};
  high_score_address.entity = ClassificationResult("address", 1.0);

  ActionSuggestionAnnotation phone_number;
  phone_number.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
  phone_number.entity = ClassificationResult("phone", 0.5);

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"view_map",
                              /*score=*/0.5,
                              /*priority_score=*/1.0,
                              /*annotations=*/{low_score_address}});
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"view_map",
                              /*score=*/1.0,
                              /*priority_score=*/2.0,
                              /*annotations=*/{high_score_address}});
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"call_phone",
                              /*score=*/0.5,
                              /*priority_score=*/1.0,
                              /*annotations=*/{phone_number}});

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &options));
  const RankingOptions* ranking_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      ranking_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("view_map", "", 1.0),
                                         IsAction("call_phone", "", 0.5)}));
}
140
// Two `view_map` actions annotate the same address span. The one with the
// higher priority_score (2.0) is expected to survive even though its score
// (0.6) is lower than the other's (1.0).
TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};

  ActionSuggestionAnnotation high_priority_address;
  high_priority_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                                /*text=*/"742 Evergreen Terrace"};
  high_priority_address.entity = ClassificationResult("address", 0.5);

  ActionSuggestionAnnotation high_score_address;
  high_score_address.span = {/*message_index=*/0, /*span=*/{0, 21},
                             /*text=*/"742 Evergreen Terrace"};
  high_score_address.entity = ClassificationResult("address", 1.0);

  ActionSuggestionAnnotation phone_number;
  phone_number.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
  phone_number.entity = ClassificationResult("phone", 0.5);

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"view_map",
                              /*score=*/0.6,
                              /*priority_score=*/2.0,
                              /*annotations=*/{high_priority_address}});
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"view_map",
                              /*score=*/1.0,
                              /*priority_score=*/1.0,
                              /*annotations=*/{high_score_address}});
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"call_phone",
                              /*score=*/0.5,
                              /*priority_score=*/1.0,
                              /*annotations=*/{phone_number}});

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &options));
  const RankingOptions* ranking_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      ranking_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray(
          {IsAction("view_map", "",
                    0.6),  // lower score wins, as priority score is higher
           IsAction("call_phone", "", 0.5)}));
}
195
// A `call_phone` action on "911" overlaps a `copy_code` action on "A-911".
// Both have score 1.0; only `copy_code`, which has the higher priority_score,
// is expected to remain.
TEST(RankingTest, DeduplicatesConflictingActions) {
  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};

  ActionSuggestionAnnotation phone_annotation;
  phone_annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
                           /*text=*/"911"};
  phone_annotation.entity = ClassificationResult("phone", 1.0);

  ActionSuggestionAnnotation code_annotation;
  code_annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
                          /*text=*/"A-911"};
  code_annotation.entity = ClassificationResult("code", 1.0);

  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"call_phone",
                              /*score=*/1.0,
                              /*priority_score=*/1.0,
                              /*annotations=*/{phone_annotation}});
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"copy_code",
                              /*score=*/1.0,
                              /*priority_score=*/2.0,
                              /*annotations=*/{code_annotation}});

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &options));
  const RankingOptions* ranking_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      ranking_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
}
233
// Ranking is delegated to a zlib-compressed Lua script that keeps only the
// actions whose type is not "text_reply". The ranker must decompress the
// script with the supplied decompressor before running it.
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  // Returns the ids of every action that is not a "text_reply".
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";
  RankingOptionsT options;
  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  // Fail this test cleanly instead of dereferencing a null pointer below.
  ASSERT_NE(compressor, nullptr);
  compressor->Compress(test_snippet,
                       options.compressed_lua_ranking_script.get());
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_NE(decompressor, nullptr);
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // If ranker construction fails, report a test failure rather than crashing
  // the whole test binary on the dereference below.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
270
// With `suppress_smart_replies_with_actions` enabled, the smart reply is
// removed when a `call_phone` action is also suggested for the conversation.
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});
  RankingOptionsT options;
  options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of dereferencing a null ranker.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
}
300
// With `group_by_annotations` enabled, the two actions sharing the phone
// annotation are kept together. The unannotated text reply therefore comes
// after both of them.
TEST(RankingTest, GroupsActionsByAnnotations) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  options.group_by_annotations = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of dereferencing a null ranker.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  // The text reply should be last, even though it has a higher score than the
  // `add_contact` action.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("add_contact", "", 0.0),
                                 IsAction("text_reply", "How are you?", 0.5)}));
}
341
// Combines `group_by_annotations` with priority-score sorting. Actions sharing
// the phone annotation form one group, ordered by priority_score. The group is
// ranked ahead of the unannotated text reply.
TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/2.0,
                              /*priority_score=*/0.0});
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact2",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.group_by_annotations = true;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of dereferencing a null ranker.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  // The text reply should be last even though its score is higher than every
  // other score, because its priority_score is lower than the maximum
  // priority_score of the actions with the 'phone' annotation.
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({
                  // Group 1 (Phone annotation)
                  IsAction("add_contact2", "", 0.5),            // priority_score=1.0
                  IsAction("add_contact", "", 0.0),             // priority_score=1.0
                  IsAction("call_phone", "", 1.0),              // priority_score=0.0
                  IsAction("text_reply", "How are you?", 2.0),  // Group 2
              }));
}
393
// With `group_by_annotations` disabled and the default sort type, actions are
// expected in descending score order regardless of annotations or
// priority_score.
TEST(RankingTest, SortsActionsByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of dereferencing a null ranker.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("text_reply", "How are you?", 0.5),
                                 IsAction("add_contact", "", 0.0)}));
}
433
// With SORT_TYPE_PRIORITY_SCORE and no grouping, the two replies with
// priority_score 1.0 are expected ahead of the "Yes" reply. This holds even
// though "Yes" has the highest score.
TEST(RankingTest, SortsActionsByPriority) {
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &options));
  const RankingOptions* ranking_options =
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer());
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      ranking_options, /*decompressor=*/nullptr,
      /*smart_reply_action_type=*/"text_reply");

  const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
  ActionsSuggestionsResponse response;
  // The empty-text replies stand in for emoji replies; they are given a
  // higher priority_score than the plain-text "Yes".
  response.actions.push_back({/*response_text=*/"", /*type=*/"text_reply",
                              /*score=*/0.5, /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"", /*type=*/"text_reply",
                              /*score=*/0.4, /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"Yes", /*type=*/"text_reply",
                              /*score=*/1.0, /*priority_score=*/0.0});

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray(
                  {IsAction("text_reply", "", 0.5),
                   IsAction("text_reply", "", 0.4),
                   // Ranked last because of priority score
                   IsAction("text_reply", "Yes", 1.0)}));
}
468
469 } // namespace
470 } // namespace libtextclassifier3
471