xref: /aosp_15_r20/external/libtextclassifier/native/actions/ranker_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "actions/ranker.h"
18 
19 #include <string>
20 
21 #include "actions/actions_model_generated.h"
22 #include "actions/types.h"
23 #include "utils/zlib/zlib.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 
27 namespace libtextclassifier3 {
28 namespace {
29 
30 MATCHER_P3(IsAction, type, response_text, score, "") {
31   return testing::Value(arg.type, type) &&
32          testing::Value(arg.response_text, response_text) &&
33          testing::Value(arg.score, score);
34 }
35 
36 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
37 
// With deduplication enabled, two smart replies that differ only in score
// collapse into the single higher-scoring reply.
TEST(RankingTest, DeduplicationSmartReply) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0)}));
}
59 
// Deduplication merges identical smart replies, but keeps a reply whose
// serialized entity data differs, even if its type and text match.
TEST(RankingTest, DeduplicationExtraData) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0, /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5,
       /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
       /*priority_score=*/0.0,
       /*annotations=*/{}, /*serialized_entity_data=*/"test"},
  };

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
                                 // Is kept as it has different entity data.
                                 IsAction("text_reply", "hello there", 0.6)}));
}
88 
// Two view_map actions share the same address annotation. Only the one with
// the higher score (and higher priority score) survives deduplication. The
// call_phone action on a different span is kept.
TEST(RankingTest, DeduplicationAnnotations) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("view_map", "", 1.0),
                                         IsAction("call_phone", "", 0.5)}));
}
140 
// When duplicate view_map actions disagree, the one with the higher
// priority_score wins, even though its plain score is lower.
TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.6,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray(
          {IsAction("view_map", "",
                    0.6),  // Lower score wins, as its priority score is higher.
           IsAction("call_phone", "", 0.5)}));
}
195 
// "911" (phone) lies inside "A-911" (code), so the two actions overlap with
// different entities. Only copy_code, which has the higher priority_score,
// is kept.
TEST(RankingTest, DeduplicatesConflictingActions) {
  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
                       /*text=*/"A-911"};
    annotation.entity = ClassificationResult("code", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"copy_code",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
}
233 
// Ranking can be driven by a zlib-compressed Lua script. This script returns
// the ids of every action that is not a smart reply, so only the non-reply
// actions remain.
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";
  RankingOptionsT options;
  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  // The zlib instances are factory results. Check them before use so a setup
  // failure is reported instead of crashing.
  ASSERT_NE(compressor, nullptr);
  compressor->Compress(test_snippet,
                       options.compressed_lua_ranking_script.get());
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_NE(decompressor, nullptr);
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly if the ranker could not be created, e.g. because
  // the compressed script could not be loaded.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
270 
// With suppress_smart_replies_with_actions set, a smart reply is dropped when
// a non-reply action (here call_phone) is present.
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});
  RankingOptionsT options;
  options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
}
300 
// With group_by_annotations, actions sharing an annotation stay together, and
// the unannotated smart reply comes after that group.
TEST(RankingTest, GroupsActionsByAnnotations) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  options.group_by_annotations = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  // The text reply should be last, even though it has a higher score than the
  // `add_contact` action.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("add_contact", "", 0.0),
                                 IsAction("text_reply", "How are you?", 0.5)}));
}
341 
// With grouping and priority-score sorting combined, annotation groups are
// ordered by priority, and actions inside a group are ordered by priority,
// then score.
TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/2.0,
                              /*priority_score=*/0.0});
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact2",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.group_by_annotations = true;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  // The text reply should be last, even though its score is higher than any
  // other score, because its priority_score is lower than the highest
  // priority_score in the 'phone' annotation group.
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({
                  // Group 1 (Phone annotation)
                  IsAction("add_contact2", "", 0.5),  // priority_score=1.0
                  IsAction("add_contact", "", 0.0),   // priority_score=1.0
                  IsAction("call_phone", "", 1.0),    // priority_score=0.0
                  IsAction("text_reply", "How are you?", 2.0),  // Group 2
              }));
}
393 
// Without grouping, actions are ordered purely by score, so the smart reply
// lands between the two annotated actions.
TEST(RankingTest, SortsActionsByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies characters [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("text_reply", "How are you?", 0.5),
                                 IsAction("add_contact", "", 0.0)}));
}
// With priority-score sorting, replies with a higher priority_score come
// first, regardless of their plain score.
TEST(RankingTest, SortsActionsByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
  // Two distinct emoji replies, spelled as UTF-8 byte escapes so the source
  // stays ASCII-only and cannot be mangled by encoding conversions.
  const char* const grin_emoji = "\xF0\x9F\x98\x81";       // U+1F601
  const char* const thumbs_up_emoji = "\xF0\x9F\x91\x8D";  // U+1F44D
  ActionsSuggestionsResponse response;
  // Emoji replies are given a higher priority_score.
  response.actions.push_back({/*response_text=*/grin_emoji,
                              /*type=*/"text_reply",
                              /*score=*/0.5,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/thumbs_up_emoji,
                              /*type=*/"text_reply",
                              /*score=*/0.4,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"Yes",
                              /*type=*/"text_reply",
                              /*score=*/1.0,
                              /*priority_score=*/0.0});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail the test cleanly instead of crashing on a null dereference if the
  // ranker could not be created from these options.
  ASSERT_NE(ranker, nullptr);

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions, testing::ElementsAreArray(
                                    {IsAction("text_reply", grin_emoji, 0.5),
                                     IsAction("text_reply", thumbs_up_emoji, 0.4),
                                     // Ranked last because of priority score.
                                     IsAction("text_reply", "Yes", 1.0)}));
}
468 
469 }  // namespace
470 }  // namespace libtextclassifier3
471