xref: /aosp_15_r20/external/libtextclassifier/native/actions/actions-suggestions_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/actions-suggestions.h"
18 
19 #include <fstream>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 
24 #include "actions/actions_model_generated.h"
25 #include "actions/test-utils.h"
26 #include "actions/zlib-utils.h"
27 #include "annotator/collections.h"
28 #include "annotator/types.h"
29 #include "utils/flatbuffers/flatbuffers.h"
30 #include "utils/flatbuffers/flatbuffers_generated.h"
31 #include "utils/flatbuffers/mutable.h"
32 #include "utils/grammar/utils/locale-shard-map.h"
33 #include "utils/grammar/utils/rules.h"
34 #include "utils/hash/farmhash.h"
35 #include "utils/jvm-test-utils.h"
36 #include "utils/test-data-test-utils.h"
37 #include "gmock/gmock.h"
38 #include "gtest/gtest.h"
39 #include "flatbuffers/flatbuffers.h"
40 #include "flatbuffers/reflection.h"
41 
42 namespace libtextclassifier3 {
43 namespace {
44 
45 using ::testing::ElementsAre;
46 using ::testing::FloatEq;
47 using ::testing::IsEmpty;
48 using ::testing::NotNull;
49 using ::testing::SizeIs;
50 
// Names of the test model files, relative to the directory returned by
// GetModelPath().

// Base and grammar-based test models.
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kModelGrammarFileName[] =
    "actions_suggestions_grammar_test.model";

// Multi-task test models.
constexpr char kMultiTaskTF2TestModelFileName[] =
    "actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
    "actions_suggestions_test.multi_task_9heads.model";

// Hash-gram test model.
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";

// Multi-task "sr" test model variants.
constexpr char kMultiTaskSrP13nModelFileName[] =
    "actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kMultiTaskSrEmojiConceptModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji_concept.model";

// TFLite test models.
constexpr char kSensitiveTFliteModelFileName[] =
    "actions_suggestions_test.sensitive_tflite.model";
constexpr char kLiveRelayTFLiteModelFileName[] =
    "actions_suggestions_test.live_relay.model";
70 
// Reads the entire contents of `file_name` into a string, byte for byte.
// The file is opened in binary mode because the callers read serialized
// flatbuffer models; text mode could translate newline bytes on platforms
// that do so and corrupt the buffer. Returns an empty string if the file
// cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
75 
GetModelPath()76 std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
77 
78 class ActionsSuggestionsTest : public testing::Test {
79  protected:
ActionsSuggestionsTest()80   explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
LoadTestModel(const std::string model_file_name)81   std::unique_ptr<ActionsSuggestions> LoadTestModel(
82       const std::string model_file_name) {
83     return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
84                                         unilib_.get());
85   }
LoadHashGramTestModel()86   std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
87     return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
88                                         unilib_.get());
89   }
LoadMultiTaskTestModel()90   std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
91     return ActionsSuggestions::FromPath(
92         GetModelPath() + kMultiTaskModelFileName, unilib_.get());
93   }
94 
LoadMultiTaskSrP13nTestModel()95   std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
96     return ActionsSuggestions::FromPath(
97         GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
98   }
99   std::unique_ptr<UniLib> unilib_;
100 };
101 
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  // Loading the base test model must produce an instance.
  const std::unique_ptr<ActionsSuggestions> model =
      LoadTestModel(kModelFileName);
  EXPECT_THAT(model, NotNull());
}
105 
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());

  // The message ends in "\xf0\x9f", the first two bytes of a four-byte UTF-8
  // sequence, i.e. a truncated code point; no actions are expected.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?\xf0\x9f",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
117 
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());

  // "\xed\xa0\x80" is the three-byte encoding of the surrogate U+D800, which
  // is not valid UTF-8; the message contains four of them, so no actions are
  // expected.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
131 
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // share_location + 2 smart replies. SizeIs avoids a signed/unsigned
  // comparison and prints the container on failure.
  EXPECT_THAT(response.actions, SizeIs(3));
}
142 
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());
  // "zz" is not a locale the test model handles, so nothing is suggested.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
153 
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());

  // Annotate "home" (offsets [11, 15) of the message) as an address.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
171 
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  SetTestEntityDataSchema(actions_model.get());

  // Replace the annotation mappings with a single one: "address" annotations
  // produce a "save_location" action whose `location` entity field is filled
  // from the annotated text.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  // `builder` must outlive `actions_suggestions`, which does not own the
  // buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_THAT(actions_suggestions, NotNull());

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation. Parsing an empty buffer as a flatbuffer would be undefined,
  // and a missing field yields a null pointer, so guard both.
  ASSERT_FALSE(response.actions.front().serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions.front().serialized_entity_data.data()));
  // NOTE(review): vtable offset 6 addresses the schema's second field,
  // presumably `location` in SetTestEntityDataSchema — confirm.
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_THAT(location, NotNull());
  EXPECT_EQ(location->str(), "home");
}
224 
TEST_F(ActionsSuggestionsTest,SuggestsActionsFromAnnotationsWithNormalization)225 TEST_F(ActionsSuggestionsTest,
226        SuggestsActionsFromAnnotationsWithNormalization) {
227   const std::string actions_model_string =
228       ReadFile(GetModelPath() + kModelFileName);
229   std::unique_ptr<ActionsModelT> actions_model =
230       UnPackActionsModel(actions_model_string.c_str());
231   SetTestEntityDataSchema(actions_model.get());
232 
233   // Set custom actions from annotations config.
234   actions_model->annotation_actions_spec->annotation_mapping.clear();
235   actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
236       new AnnotationActionsSpec_::AnnotationMappingT);
237   AnnotationActionsSpec_::AnnotationMappingT* mapping =
238       actions_model->annotation_actions_spec->annotation_mapping.back().get();
239   mapping->annotation_collection = "address";
240   mapping->action.reset(new ActionSuggestionSpecT);
241   mapping->action->type = "save_location";
242   mapping->action->score = 1.0;
243   mapping->action->priority_score = 2.0;
244   mapping->entity_field.reset(new FlatbufferFieldPathT);
245   mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
246   mapping->entity_field->field.back()->field_name = "location";
247   mapping->normalization_options.reset(new NormalizationOptionsT);
248   mapping->normalization_options->codepointwise_normalization =
249       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
250 
251   flatbuffers::FlatBufferBuilder builder;
252   FinishActionsModelBuffer(builder,
253                            ActionsModel::Pack(builder, actions_model.get()));
254   std::unique_ptr<ActionsSuggestions> actions_suggestions =
255       ActionsSuggestions::FromUnownedBuffer(
256           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
257           builder.GetSize(), unilib_.get());
258 
259   AnnotatedSpan annotation;
260   annotation.span = {11, 15};
261   annotation.classification = {ClassificationResult("address", 1.0)};
262   const ActionsSuggestionsResponse response =
263       actions_suggestions->SuggestActions(
264           {{{/*user_id=*/1, "are you at home?",
265              /*reference_time_ms_utc=*/0,
266              /*reference_timezone=*/"Europe/Zurich",
267              /*annotations=*/{annotation},
268              /*locales=*/"en"}}});
269   ASSERT_GE(response.actions.size(), 1);
270   EXPECT_EQ(response.actions.front().type, "save_location");
271   EXPECT_EQ(response.actions.front().score, 1.0);
272 
273   // Check that the `location` entity field holds the normalized text of the
274   // annotation.
275   const flatbuffers::Table* entity =
276       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
277           response.actions.front().serialized_entity_data.data()));
278   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
279             "HOME");
280 }
281 
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());

  // Two "flight" annotations on the two "LX38" mentions (offsets [11, 15) and
  // [35, 39)) with different scores, plus one "email" annotation.
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  // NOTE(review): "[email protected]" below looks like an artifact of an
  // email-obfuscation filter; the span {43, 56} presumably targeted a
  // 13-character address in the original source — confirm against upstream.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or [email protected].",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  // Only the higher-scoring flight action is expected, followed by the email
  // action (contrast with the NoDeduplication test below).
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
311 
TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  // `builder` must outlive `actions_suggestions`, which does not own the
  // buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing null if the model did not build.
  ASSERT_THAT(actions_suggestions, NotNull());

  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  // NOTE(review): "[email protected]" below looks like an artifact of an
  // email-obfuscation filter; the span {43, 56} presumably targeted a
  // 13-character address in the original source — confirm against upstream.
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or [email protected].",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  // Without deduplication both flight annotations yield an action.
  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
354 
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)355 ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
356     const std::function<void(ActionsModelT*)>& set_config_fn,
357     const UniLib* unilib = nullptr) {
358   const std::string actions_model_string =
359       ReadFile(GetModelPath() + kModelFileName);
360   std::unique_ptr<ActionsModelT> actions_model =
361       UnPackActionsModel(actions_model_string.c_str());
362 
363   // Set custom config.
364   set_config_fn(actions_model.get());
365 
366   // Disable smart reply for easier testing.
367   actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
368 
369   flatbuffers::FlatBufferBuilder builder;
370   FinishActionsModelBuffer(builder,
371                            ActionsModel::Pack(builder, actions_model.get()));
372   std::unique_ptr<ActionsSuggestions> actions_suggestions =
373       ActionsSuggestions::FromUnownedBuffer(
374           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
375           builder.GetSize(), unilib);
376 
377   AnnotatedSpan flight_annotation;
378   flight_annotation.span = {15, 19};
379   flight_annotation.classification = {ClassificationResult("flight", 2.0)};
380   AnnotatedSpan email_annotation;
381   email_annotation.span = {0, 16};
382   email_annotation.classification = {ClassificationResult("email", 1.0)};
383 
384   return actions_suggestions->SuggestActions(
385       {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
386          "[email protected]",
387          /*reference_time_ms_utc=*/0,
388          /*reference_timezone=*/"Europe/Zurich",
389          /*annotations=*/
390          {email_annotation},
391          /*locales=*/"en"},
392         {/*user_id=*/2,
393          "[email protected]",
394          /*reference_time_ms_utc=*/0,
395          /*reference_timezone=*/"Europe/Zurich",
396          /*annotations=*/
397          {email_annotation},
398          /*locales=*/"en"},
399         {/*user_id=*/1,
400          "[email protected]",
401          /*reference_time_ms_utc=*/0,
402          /*reference_timezone=*/"Europe/Zurich",
403          /*annotations=*/
404          {email_annotation},
405          /*locales=*/"en"},
406         {/*user_id=*/1,
407          "I am on flight LX38.",
408          /*reference_time_ms_utc=*/0,
409          /*reference_timezone=*/"Europe/Zurich",
410          /*annotations=*/
411          {flight_annotation},
412          /*locales=*/"en"}}});
413 }
414 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
  // History limited to one message from anyone and from the last person:
  // only the final (flight) message is expected to contribute.
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_THAT(response.actions, SizeIs(1));
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
429 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
  // Up to three messages from the last person (user 1) but only one from
  // anyone else: user 1's flight and email messages are expected.
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      unilib_.get());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
445 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
  // Two messages of history from anyone: the last two messages are expected to
  // contribute one flight and one email action.
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
461 
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessages) {
  // Three messages of history from anyone: the three non-local messages are
  // expected to contribute.
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
479 
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  // A history window large enough to cover the whole conversation, but local
  // user messages excluded and history cut at the last sent message: the
  // local user's email is expected to be left out.
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
497 
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  // Local user messages included and no cut at the last sent message: all four
  // messages are expected to contribute.
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
516 
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")517 void TestSuggestActionsWithThreshold(
518     const std::function<void(ActionsModelT*)>& set_value_fn,
519     const UniLib* unilib = nullptr, const int expected_size = 0,
520     const std::string& preconditions_overwrite = "") {
521   const std::string actions_model_string =
522       ReadFile(GetModelPath() + kModelFileName);
523   std::unique_ptr<ActionsModelT> actions_model =
524       UnPackActionsModel(actions_model_string.c_str());
525   set_value_fn(actions_model.get());
526   flatbuffers::FlatBufferBuilder builder;
527   FinishActionsModelBuffer(builder,
528                            ActionsModel::Pack(builder, actions_model.get()));
529   std::unique_ptr<ActionsSuggestions> actions_suggestions =
530       ActionsSuggestions::FromUnownedBuffer(
531           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
532           builder.GetSize(), unilib, preconditions_overwrite);
533   ASSERT_TRUE(actions_suggestions);
534   const ActionsSuggestionsResponse response =
535       actions_suggestions->SuggestActions(
536           {{{/*user_id=*/1, "I have the low-ground. Where are you?",
537              /*reference_time_ms_utc=*/0,
538              /*reference_timezone=*/"Europe/Zurich",
539              /*annotations=*/{}, /*locales=*/"en"}}});
540   EXPECT_LE(response.actions.size(), expected_size);
541 }
542 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
  // Raising the smart-reply triggering score to 1.0 should leave no smart
  // replies, only actions.
  const auto raise_triggering_score = [](ActionsModelT* model) {
    model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, unilib_.get(),
                                  /*expected_size=*/1);
}
552 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
  // Requiring a reply score of at least 1.0 should leave no smart replies, only
  // actions.
  const auto raise_min_reply_score = [](ActionsModelT* model) {
    model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_min_reply_score, unilib_.get(),
                                  /*expected_size=*/1);
}
562 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
  // The test model has no sensitive-topic prediction, so even a zero
  // threshold should not suppress suggestions.
  const auto zero_sensitive_threshold = [](ActionsModelT* model) {
    model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(zero_sensitive_threshold, unilib_.get(),
                                  /*expected_size=*/4);
}
571 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
  // A maximum input length of 0 rejects the input: no actions (default
  // expected_size of 0).
  const auto forbid_any_length = [](ActionsModelT* model) {
    model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(forbid_any_length, unilib_.get());
}
579 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
  // The test message is shorter than 100, so no actions are expected (default
  // expected_size of 0).
  const auto require_long_input = [](ActionsModelT* model) {
    model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(require_long_input, unilib_.get());
}
587 
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
  // Serialize preconditions with a maximum input length of 0 and pass them as
  // an overwrite; the model itself is left untouched, so any suppression must
  // come from the overwrite.
  TriggeringPreconditionsT overwrite;
  overwrite.max_input_length = 0;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(TriggeringPreconditions::Pack(fbb, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize());

  TestSuggestActionsWithThreshold([](ActionsModelT*) {}, unilib_.get(),
                                  /*expected_size=*/0, serialized_overwrite);
}
601 
602 #ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
  // Enable suppression on low-confidence input and add a rule matching
  // "low-ground", which occurs in the helper's test message; no actions are
  // expected (default expected_size of 0).
  const auto mark_low_ground_low_confidence = [](ActionsModelT* model) {
    model->preconditions->suppress_on_low_confidence_input = true;
    model->low_confidence_rules.reset(new RulesModelT);
    auto& low_confidence_regexes = model->low_confidence_rules->regex_rule;
    low_confidence_regexes.emplace_back(new RulesModel_::RegexRuleT);
    low_confidence_regexes.back()->pattern = "low-ground";
  };
  TestSuggestActionsWithThreshold(mark_low_ground_low_confidence,
                                  unilib_.get());
}
615 
TEST_F(ActionsSuggestionsTest,SuggestsActionsLowConfidenceInputOutput)616 TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
617   const std::string actions_model_string =
618       ReadFile(GetModelPath() + kModelFileName);
619   std::unique_ptr<ActionsModelT> actions_model =
620       UnPackActionsModel(actions_model_string.c_str());
621   // Add custom triggering rule.
622   actions_model->rules.reset(new RulesModelT());
623   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
624   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
625   rule->pattern = "^(?i:hello\\s(there))$";
626   {
627     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
628         new RulesModel_::RuleActionSpecT);
629     rule_action->action.reset(new ActionSuggestionSpecT);
630     rule_action->action->type = "text_reply";
631     rule_action->action->response_text = "General Desaster!";
632     rule_action->action->score = 1.0f;
633     rule_action->action->priority_score = 1.0f;
634     rule->actions.push_back(std::move(rule_action));
635   }
636   {
637     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
638         new RulesModel_::RuleActionSpecT);
639     rule_action->action.reset(new ActionSuggestionSpecT);
640     rule_action->action->type = "text_reply";
641     rule_action->action->response_text = "General Kenobi!";
642     rule_action->action->score = 1.0f;
643     rule_action->action->priority_score = 1.0f;
644     rule->actions.push_back(std::move(rule_action));
645   }
646 
647   // Add input-output low confidence rule.
648   actions_model->preconditions->suppress_on_low_confidence_input = true;
649   actions_model->low_confidence_rules.reset(new RulesModelT);
650   actions_model->low_confidence_rules->regex_rule.emplace_back(
651       new RulesModel_::RegexRuleT);
652   actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
653   actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
654       "(?i:desaster)";
655 
656   flatbuffers::FlatBufferBuilder builder;
657   FinishActionsModelBuffer(builder,
658                            ActionsModel::Pack(builder, actions_model.get()));
659   std::unique_ptr<ActionsSuggestions> actions_suggestions =
660       ActionsSuggestions::FromUnownedBuffer(
661           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
662           builder.GetSize(), unilib_.get());
663   ASSERT_TRUE(actions_suggestions);
664   const ActionsSuggestionsResponse response =
665       actions_suggestions->SuggestActions(
666           {{{/*user_id=*/1, "hello there",
667              /*reference_time_ms_utc=*/0,
668              /*reference_timezone=*/"Europe/Zurich",
669              /*annotations=*/{}, /*locales=*/"en"}}});
670   ASSERT_GE(response.actions.size(), 1);
671   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
672 }
673 
TEST_F(ActionsSuggestionsTest,SuggestsActionsLowConfidenceInputOutputOverwrite)674 TEST_F(ActionsSuggestionsTest,
675        SuggestsActionsLowConfidenceInputOutputOverwrite) {
676   const std::string actions_model_string =
677       ReadFile(GetModelPath() + kModelFileName);
678   std::unique_ptr<ActionsModelT> actions_model =
679       UnPackActionsModel(actions_model_string.c_str());
680   actions_model->low_confidence_rules.reset();
681 
682   // Add custom triggering rule.
683   actions_model->rules.reset(new RulesModelT());
684   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
685   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
686   rule->pattern = "^(?i:hello\\s(there))$";
687   {
688     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
689         new RulesModel_::RuleActionSpecT);
690     rule_action->action.reset(new ActionSuggestionSpecT);
691     rule_action->action->type = "text_reply";
692     rule_action->action->response_text = "General Desaster!";
693     rule_action->action->score = 1.0f;
694     rule_action->action->priority_score = 1.0f;
695     rule->actions.push_back(std::move(rule_action));
696   }
697   {
698     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
699         new RulesModel_::RuleActionSpecT);
700     rule_action->action.reset(new ActionSuggestionSpecT);
701     rule_action->action->type = "text_reply";
702     rule_action->action->response_text = "General Kenobi!";
703     rule_action->action->score = 1.0f;
704     rule_action->action->priority_score = 1.0f;
705     rule->actions.push_back(std::move(rule_action));
706   }
707 
708   // Add custom triggering rule via overwrite.
709   actions_model->preconditions->low_confidence_rules.reset();
710   TriggeringPreconditionsT preconditions;
711   preconditions.suppress_on_low_confidence_input = true;
712   preconditions.low_confidence_rules.reset(new RulesModelT);
713   preconditions.low_confidence_rules->regex_rule.emplace_back(
714       new RulesModel_::RegexRuleT);
715   preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
716   preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
717       "(?i:desaster)";
718   flatbuffers::FlatBufferBuilder preconditions_builder;
719   preconditions_builder.Finish(
720       TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
721   std::string serialize_preconditions = std::string(
722       reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
723       preconditions_builder.GetSize());
724 
725   flatbuffers::FlatBufferBuilder builder;
726   FinishActionsModelBuffer(builder,
727                            ActionsModel::Pack(builder, actions_model.get()));
728   std::unique_ptr<ActionsSuggestions> actions_suggestions =
729       ActionsSuggestions::FromUnownedBuffer(
730           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
731           builder.GetSize(), unilib_.get(), serialize_preconditions);
732 
733   ASSERT_TRUE(actions_suggestions);
734   const ActionsSuggestionsResponse response =
735       actions_suggestions->SuggestActions(
736           {{{/*user_id=*/1, "hello there",
737              /*reference_time_ms_utc=*/0,
738              /*reference_timezone=*/"Europe/Zurich",
739              /*annotations=*/{}, /*locales=*/"en"}}});
740   ASSERT_GE(response.actions.size(), 1);
741   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
742 }
743 #endif
744 
// Verifies that, with sensitive-topic suppression enabled and a maximum
// allowed score of 0.0, no actions are returned, including those that would
// otherwise be derived from an address annotation.
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced (a negative output index
  // in the TFLite model spec).
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail the test cleanly instead of dereferencing a null pointer below.
  ASSERT_TRUE(actions_suggestions);

  // Address annotation covering "home" in the message.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
778 
// Verifies that, with a conversation history length of 10, a two-message
// conversation whose last message carries an address annotation yields a
// top-ranked view_map action with score 1.0.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail the test cleanly instead of dereferencing a null pointer below.
  ASSERT_TRUE(actions_suggestions);

  // Address annotation attached to the second message (user_id 1).
  // NOTE(review): span {11, 15} covers "ou a" in "good! are you at home?";
  // "home" would be {17, 21}. The assertions below check only the action
  // type and score. Confirm whether this span is intended.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
814 
// Verifies the action list produced by the TF2 multi-task test model for a
// single greeting message.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskTF2TestModelFileName);
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hello how are you",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal size check: the indexing below (up to [3]) would read out of
  // bounds if fewer than four actions were returned.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].response_text, "Okay");
  EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
  EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
}
830 
// Verifies that the grammar test model produces a call_phone action, with one
// annotation, for a message containing a phone-like code.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelGrammarFileName);
  ASSERT_TRUE(actions_suggestions);
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("phone", 0.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Contact us at: *1234",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  const ActionSuggestion& action = response.actions.front();
  EXPECT_EQ(action.type, "call_phone");
  EXPECT_EQ(action.score, 0.0);
  EXPECT_EQ(action.priority_score, 0.0);
  // Fatal: annotations.front() below is undefined on an empty vector.
  ASSERT_THAT(action.annotations, SizeIs(1));
  // The expected span [15, 20) is "*1234". It differs from the supplied
  // annotation span, presumably because the grammar produces its own
  // phone annotation.
  EXPECT_EQ(action.annotations.front().span.span.first, 15);
  EXPECT_EQ(action.annotations.front().span.span.second, 20);
}
852 
// Verifies that a flight annotation yields a top-ranked track_flight action.
// The action must carry exactly that annotation, attributed to message 0.
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  // NOTE(review): in "I'm on LX38?" the span {8, 12} covers "X38?"; the
  // flight code "LX38" is {7, 11} (the span DeduplicateConflictingActions
  // uses). Confirm whether this offset is intended.
  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // Fatal: annotations[0] below is undefined if the vector is empty.
  ASSERT_THAT(response.actions[0].annotations, SizeIs(1));
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
876 
877 #ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest,CreateActionsFromRules)878 TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
879   const std::string actions_model_string =
880       ReadFile(GetModelPath() + kModelFileName);
881   std::unique_ptr<ActionsModelT> actions_model =
882       UnPackActionsModel(actions_model_string.c_str());
883   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
884 
885   actions_model->rules.reset(new RulesModelT());
886   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
887   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
888   rule->pattern = "^(?i:hello\\s(there))$";
889   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
890   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
891   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
892   action->type = "text_reply";
893   action->response_text = "General Kenobi!";
894   action->score = 1.0f;
895   action->priority_score = 1.0f;
896 
897   // Set capturing groups for entity data.
898   rule->actions.back()->capturing_group.emplace_back(
899       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
900   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
901       rule->actions.back()->capturing_group.back().get();
902   greeting_group->group_id = 0;
903   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
904   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
905   greeting_group->entity_field->field.back()->field_name = "greeting";
906   rule->actions.back()->capturing_group.emplace_back(
907       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
908   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
909       rule->actions.back()->capturing_group.back().get();
910   location_group->group_id = 1;
911   location_group->entity_field.reset(new FlatbufferFieldPathT);
912   location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
913   location_group->entity_field->field.back()->field_name = "location";
914 
915   // Set test entity data schema.
916   SetTestEntityDataSchema(actions_model.get());
917 
918   // Use meta data to generate custom serialized entity data.
919   MutableFlatbufferBuilder entity_data_builder(
920       flatbuffers::GetRoot<reflection::Schema>(
921           actions_model->actions_entity_data_schema.data()));
922   std::unique_ptr<MutableFlatbuffer> entity_data =
923       entity_data_builder.NewRoot();
924   entity_data->Set("person", "Kenobi");
925   action->serialized_entity_data = entity_data->Serialize();
926 
927   flatbuffers::FlatBufferBuilder builder;
928   FinishActionsModelBuffer(builder,
929                            ActionsModel::Pack(builder, actions_model.get()));
930   std::unique_ptr<ActionsSuggestions> actions_suggestions =
931       ActionsSuggestions::FromUnownedBuffer(
932           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
933           builder.GetSize(), unilib_.get());
934 
935   const ActionsSuggestionsResponse response =
936       actions_suggestions->SuggestActions(
937           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
938              /*reference_timezone=*/"Europe/Zurich",
939              /*annotations=*/{}, /*locales=*/"en"}}});
940   EXPECT_GE(response.actions.size(), 1);
941   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
942 
943   // Check entity data.
944   const flatbuffers::Table* entity =
945       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
946           response.actions[0].serialized_entity_data.data()));
947   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
948             "hello there");
949   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
950             "there");
951   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
952             "Kenobi");
953 }
954 
TEST_F(ActionsSuggestionsTest,CreateActionsFromRulesWithNormalization)955 TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
956   const std::string actions_model_string =
957       ReadFile(GetModelPath() + kModelFileName);
958   std::unique_ptr<ActionsModelT> actions_model =
959       UnPackActionsModel(actions_model_string.c_str());
960   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
961 
962   actions_model->rules.reset(new RulesModelT());
963   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
964   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
965   rule->pattern = "^(?i:hello\\sthere)$";
966   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
967   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
968   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
969   action->type = "text_reply";
970   action->response_text = "General Kenobi!";
971   action->score = 1.0f;
972   action->priority_score = 1.0f;
973 
974   // Set capturing groups for entity data.
975   rule->actions.back()->capturing_group.emplace_back(
976       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
977   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
978       rule->actions.back()->capturing_group.back().get();
979   greeting_group->group_id = 0;
980   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
981   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
982   greeting_group->entity_field->field.back()->field_name = "greeting";
983   greeting_group->normalization_options.reset(new NormalizationOptionsT);
984   greeting_group->normalization_options->codepointwise_normalization =
985       NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
986       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
987 
988   // Set test entity data schema.
989   SetTestEntityDataSchema(actions_model.get());
990 
991   flatbuffers::FlatBufferBuilder builder;
992   FinishActionsModelBuffer(builder,
993                            ActionsModel::Pack(builder, actions_model.get()));
994   std::unique_ptr<ActionsSuggestions> actions_suggestions =
995       ActionsSuggestions::FromUnownedBuffer(
996           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
997           builder.GetSize(), unilib_.get());
998 
999   const ActionsSuggestionsResponse response =
1000       actions_suggestions->SuggestActions(
1001           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
1002              /*reference_timezone=*/"Europe/Zurich",
1003              /*annotations=*/{}, /*locales=*/"en"}}});
1004   EXPECT_GE(response.actions.size(), 1);
1005   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
1006 
1007   // Check entity data.
1008   const flatbuffers::Table* entity =
1009       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
1010           response.actions[0].serialized_entity_data.data()));
1011   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
1012             "HELLOTHERE");
1013 }
1014 
// Verifies that a capturing group configured with `text_reply` turns the
// captured (lowercased) keyword into a text reply suggestion.
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);

  // Group 1 (the keyword) becomes the reply text, normalized to lowercase.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;
  code_group->normalization_options.reset(new NormalizationOptionsT);
  code_group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal: response.actions[0] is accessed below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "stop");
}
1059 
TEST_F(ActionsSuggestionsTest,CreatesActionsFromGrammarRules)1060 TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
1061   const std::string actions_model_string =
1062       ReadFile(GetModelPath() + kModelFileName);
1063   std::unique_ptr<ActionsModelT> actions_model =
1064       UnPackActionsModel(actions_model_string.c_str());
1065   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1066 
1067   actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
1068 
1069   // Set tokenizer options.
1070   RulesModel_::GrammarRulesT* action_grammar_rules =
1071       actions_model->rules->grammar_rules.get();
1072   action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1073   action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
1074   action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
1075       false;
1076 
1077   // Setup test rules.
1078   action_grammar_rules->rules.reset(new grammar::RulesSetT);
1079   grammar::LocaleShardMap locale_shard_map =
1080       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1081   grammar::Rules rules(locale_shard_map);
1082   rules.Add(
1083       "<knock>", {"<^>", "ventura", "!?", "<$>"},
1084       /*callback=*/
1085       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1086       /*callback_param=*/0);
1087   rules.Finalize().Serialize(/*include_debug_information=*/false,
1088                              action_grammar_rules->rules.get());
1089   action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1090   RulesModel_::RuleActionSpecT* actions_spec =
1091       action_grammar_rules->actions.back().get();
1092   actions_spec->action.reset(new ActionSuggestionSpecT);
1093   actions_spec->action->response_text = "Yes, Satan?";
1094   actions_spec->action->priority_score = 1.0;
1095   actions_spec->action->score = 1.0;
1096   actions_spec->action->type = "text_reply";
1097   action_grammar_rules->rule_match.emplace_back(
1098       new RulesModel_::GrammarRules_::RuleMatchT);
1099   action_grammar_rules->rule_match.back()->action_id.push_back(0);
1100 
1101   flatbuffers::FlatBufferBuilder builder;
1102   FinishActionsModelBuffer(builder,
1103                            ActionsModel::Pack(builder, actions_model.get()));
1104   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1105       ActionsSuggestions::FromUnownedBuffer(
1106           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1107           builder.GetSize(), unilib_.get());
1108 
1109   const ActionsSuggestionsResponse response =
1110       actions_suggestions->SuggestActions(
1111           {{{/*user_id=*/1, "Ventura!",
1112              /*reference_time_ms_utc=*/0,
1113              /*reference_timezone=*/"Europe/Zurich",
1114              /*annotations=*/{}, /*locales=*/"en"}}});
1115 
1116   EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
1117 }
1118 
1119 #if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
TEST_F(ActionsSuggestionsTest,CreatesActionsWithAnnotationsFromGrammarRules)1120 TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
1121   std::unique_ptr<Annotator> annotator =
1122       Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
1123   const std::string actions_model_string =
1124       ReadFile(GetModelPath() + kModelFileName);
1125   std::unique_ptr<ActionsModelT> actions_model =
1126       UnPackActionsModel(actions_model_string.c_str());
1127   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1128 
1129   actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
1130 
1131   // Set tokenizer options.
1132   RulesModel_::GrammarRulesT* action_grammar_rules =
1133       actions_model->rules->grammar_rules.get();
1134   action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1135   action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
1136   action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
1137       false;
1138 
1139   // Setup test rules.
1140   action_grammar_rules->rules.reset(new grammar::RulesSetT);
1141   grammar::LocaleShardMap locale_shard_map =
1142       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1143   grammar::Rules rules(locale_shard_map);
1144   rules.Add(
1145       "<event>", {"it", "is", "at", "<time>"},
1146       /*callback=*/
1147       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1148       /*callback_param=*/0);
1149   rules.BindAnnotation("<time>", "time");
1150   rules.AddAnnotation("datetime");
1151   rules.Finalize().Serialize(/*include_debug_information=*/false,
1152                              action_grammar_rules->rules.get());
1153   action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1154   RulesModel_::RuleActionSpecT* actions_spec =
1155       action_grammar_rules->actions.back().get();
1156   actions_spec->action.reset(new ActionSuggestionSpecT);
1157   actions_spec->action->priority_score = 1.0;
1158   actions_spec->action->score = 1.0;
1159   actions_spec->action->type = "create_event";
1160   action_grammar_rules->rule_match.emplace_back(
1161       new RulesModel_::GrammarRules_::RuleMatchT);
1162   action_grammar_rules->rule_match.back()->action_id.push_back(0);
1163 
1164   flatbuffers::FlatBufferBuilder builder;
1165   FinishActionsModelBuffer(builder,
1166                            ActionsModel::Pack(builder, actions_model.get()));
1167   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1168       ActionsSuggestions::FromUnownedBuffer(
1169           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1170           builder.GetSize(), unilib_.get());
1171 
1172   const ActionsSuggestionsResponse response =
1173       actions_suggestions->SuggestActions(
1174           {{{/*user_id=*/1, "it is at 10:30",
1175              /*reference_time_ms_utc=*/0,
1176              /*reference_timezone=*/"Europe/Zurich",
1177              /*annotations=*/{}, /*locales=*/"en"}}},
1178           annotator.get());
1179 
1180   EXPECT_THAT(response.actions, ElementsAre(IsActionOfType("create_event")));
1181 }
1182 #endif
1183 
// Verifies that adding a regex rule which duplicates an action the model
// already produces (location sharing for "Where are you?") does not increase
// the number of returned actions.
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  bool has_location_sharing_action = false;
  for (const ActionSuggestion& action : response.actions) {
    if (action.type == ActionsSuggestionsTypes::ShareLocation()) {
      has_location_sharing_action = true;
      break;
    }
  }
  EXPECT_TRUE(has_location_sharing_action);
  const size_t num_actions = response.actions.size();

  // Add a custom rule producing the same location sharing action.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  actions_model->rules->regex_rule.back()->pattern =
      "^(?i:where are you[.?]?)$";
  actions_model->rules->regex_rule.back()->actions.emplace_back(
      new RulesModel_::RuleActionSpecT);
  actions_model->rules->regex_rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->regex_rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestionsTypes::ShareLocation();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  // The duplicate must have been merged, so the count is unchanged.
  EXPECT_THAT(response.actions, SizeIs(num_actions));
}
TEST_F(ActionsSuggestionsTest,DeduplicateConflictingActions)1237 TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
1238   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1239       LoadTestModel(kModelFileName);
1240   AnnotatedSpan annotation;
1241   annotation.span = {7, 11};
1242   annotation.classification = {
1243       ClassificationResult(Collections::Flight(), 1.0)};
1244   ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
1245       {{{/*user_id=*/1, "I'm on LX38",
1246          /*reference_time_ms_utc=*/0,
1247          /*reference_timezone=*/"Europe/Zurich",
1248          /*annotations=*/{annotation},
1249          /*locales=*/"en"}}});
1250 
1251   // Check that the phone actions are present.
1252   EXPECT_GE(response.actions.size(), 1);
1253   EXPECT_EQ(response.actions[0].type, "track_flight");
1254 
1255   // Add custom rule.
1256   const std::string actions_model_string =
1257       ReadFile(GetModelPath() + kModelFileName);
1258   std::unique_ptr<ActionsModelT> actions_model =
1259       UnPackActionsModel(actions_model_string.c_str());
1260   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1261 
1262   actions_model->rules.reset(new RulesModelT());
1263   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
1264   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
1265   rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
1266   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1267   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
1268   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
1269   action->score = 1.0f;
1270   action->priority_score = 2.0f;
1271   action->type = "test_code";
1272   rule->actions.back()->capturing_group.emplace_back(
1273       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
1274   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
1275       rule->actions.back()->capturing_group.back().get();
1276   code_group->group_id = 1;
1277   code_group->annotation_name = "code";
1278   code_group->annotation_type = "code";
1279 
1280   flatbuffers::FlatBufferBuilder builder;
1281   FinishActionsModelBuffer(builder,
1282                            ActionsModel::Pack(builder, actions_model.get()));
1283   actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
1284       reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1285       builder.GetSize(), unilib_.get());
1286 
1287   response = actions_suggestions->SuggestActions(
1288       {{{/*user_id=*/1, "I'm on LX38",
1289          /*reference_time_ms_utc=*/0,
1290          /*reference_timezone=*/"Europe/Zurich",
1291          /*annotations=*/{annotation},
1292          /*locales=*/"en"}}});
1293   EXPECT_GE(response.actions.size(), 1);
1294   EXPECT_EQ(response.actions[0].type, "test_code");
1295 }
1296 #endif
1297 
// Checks that actions are ordered by score: the "work" address (score 2.0)
// must come before the "home" address (score 1.0).
TEST_F(ActionsSuggestionsTest, RanksActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_THAT(actions_suggestions, NotNull());
  std::vector<AnnotatedSpan> annotations(2);
  // "home" in "are you at home or work?".
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  // "work" in "are you at home or work?".
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // Fatal so that actions[0] / actions[1] are never read out of bounds.
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_THAT(response.actions[0].score, FloatEq(2.0f));
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_THAT(response.actions[1].score, FloatEq(1.0f));
}
1319 
// VisitActionsModel should hand the visitor a model for an existing file, and
// the visitor is expected to see no model for a path that does not exist.
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // Visitor that only reports whether a model was provided.
  const auto has_model = [](const ActionsModel* model) {
    return model != nullptr;
  };

  const std::string existing_path = GetModelPath() + kModelFileName;
  const std::string missing_path = GetModelPath() + "non_existing_model.fb";

  EXPECT_TRUE(VisitActionsModel<bool>(existing_path, has_model));
  EXPECT_FALSE(VisitActionsModel<bool>(missing_path, has_model));
}
1336 
// Exercises the hash-gram model on three single-message conversations: a
// greeting (no actions), a location question and a contact request.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
  const std::unique_ptr<ActionsSuggestions> hash_gram_model =
      LoadHashGramTestModel();
  ASSERT_THAT(hash_gram_model, NotNull());

  // Runs the model on a one-message English conversation containing `text`.
  const auto suggest_for = [&hash_gram_model](const std::string& text) {
    return hash_gram_model->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  const ActionsSuggestionsResponse greeting = suggest_for("hello");
  EXPECT_THAT(greeting.actions, IsEmpty());

  const ActionsSuggestionsResponse location_question =
      suggest_for("where are you");
  EXPECT_THAT(
      location_question.actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));

  const ActionsSuggestionsResponse contact_request =
      suggest_for("do you know johns number");
  EXPECT_THAT(
      contact_request.actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
}
1376 
1377 // Test class to expose token embedding methods for testing.
1378 class TestingMessageEmbedder : private ActionsSuggestions {
1379  public:
1380   explicit TestingMessageEmbedder(const ActionsModel* model);
1381 
1382   using ActionsSuggestions::EmbedAndFlattenTokens;
1383   using ActionsSuggestions::EmbedTokensPerMessage;
1384 
1385  protected:
1386   // EmbeddingExecutor that always returns features based on
1387   // the id of the sparse features.
1388   class FakeEmbeddingExecutor : public EmbeddingExecutor {
1389    public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1390     bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1391                       const int dest_size) const override {
1392       TC3_CHECK_GE(dest_size, 1);
1393       EXPECT_EQ(sparse_features.size(), 1);
1394       dest[0] = sparse_features.data()[0];
1395       return true;
1396     }
1397   };
1398 
1399   std::unique_ptr<UniLib> unilib_;
1400 };
1401 
TestingMessageEmbedder(const ActionsModel * model)1402 TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
1403     : unilib_(CreateUniLibForTesting()) {
1404   model_ = model;
1405   const ActionsTokenFeatureProcessorOptions* options =
1406       model->feature_processor_options();
1407   feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
1408   embedding_executor_.reset(new FakeEmbeddingExecutor());
1409   EXPECT_TRUE(
1410       EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1411   EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1412   EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1413   token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1414   EXPECT_EQ(token_embedding_size_, 1);
1415 }
1416 
1417 class EmbeddingTest : public testing::Test {
1418  protected:
EmbeddingTest()1419   explicit EmbeddingTest() {
1420     model_.feature_processor_options.reset(
1421         new ActionsTokenFeatureProcessorOptionsT);
1422     options_ = model_.feature_processor_options.get();
1423     options_->chargram_orders = {1};
1424     options_->num_buckets = 1000;
1425     options_->embedding_size = 1;
1426     options_->start_token_id = 0;
1427     options_->end_token_id = 1;
1428     options_->padding_token_id = 2;
1429     options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1430   }
1431 
CreateTestingMessageEmbedder()1432   TestingMessageEmbedder CreateTestingMessageEmbedder() {
1433     flatbuffers::FlatBufferBuilder builder;
1434     FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1435     buffer_ = builder.Release();
1436     return TestingMessageEmbedder(
1437         flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1438   }
1439 
1440   flatbuffers::DetachedBuffer buffer_;
1441   ActionsModelT model_;
1442   ActionsTokenFeatureProcessorOptionsT* options_;
1443 };
1444 
// Without min/max bounds, a single message is embedded token by token.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket("a")), FloatEq(bucket("b")),
                          FloatEq(bucket("c"))));
}
1464 
// A message shorter than min_num_tokens_per_message is padded at the end.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket("a")), FloatEq(bucket("b")),
                          FloatEq(bucket("c")),
                          FloatEq(options_->padding_token_id),
                          FloatEq(options_->padding_token_id)));
}
1487 
// A message longer than max_num_tokens_per_message loses its leading tokens.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket("b")), FloatEq(bucket("c"))));
}
1506 
// With several messages, each is padded to the longest message's length.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // 2 messages x 3 tokens each. ElementsAre also verifies the size, which the
  // element-wise indexing previously did not check at all.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket("a")), FloatEq(bucket("b")),
                          FloatEq(bucket("c")), FloatEq(bucket("d")),
                          FloatEq(bucket("e")),
                          FloatEq(options_->padding_token_id)));
}
1531 
// Flattening wraps each message in start/end tokens.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 5);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(options_->start_token_id),
                          FloatEq(bucket("a")), FloatEq(bucket("b")),
                          FloatEq(bucket("c")),
                          FloatEq(options_->end_token_id)));
}
1553 
// Fewer than min_num_total_tokens flattened tokens are padded at the end.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(options_->start_token_id),
                          FloatEq(bucket("a")), FloatEq(bucket("b")),
                          FloatEq(bucket("c")),
                          FloatEq(options_->end_token_id),
                          FloatEq(options_->padding_token_id),
                          FloatEq(options_->padding_token_id)));
}
1578 
// More than max_num_total_tokens flattened tokens are dropped from the front.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 3);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket("b")), FloatEq(bucket("c")),
                          FloatEq(options_->end_token_id)));
}
1598 
// Each flattened message gets its own start/end tokens.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 9);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(options_->start_token_id),
                          FloatEq(bucket("a")), FloatEq(bucket("b")),
                          FloatEq(bucket("c")),
                          FloatEq(options_->end_token_id),
                          FloatEq(options_->start_token_id),
                          FloatEq(bucket("d")), FloatEq(bucket("e")),
                          FloatEq(options_->end_token_id)));
}
1627 
// Truncation to max_num_total_tokens can cut into an earlier message.
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;
  // Expected embedding of `token`: the fake executor echoes the feature id,
  // i.e. the token's hash bucket.
  const auto bucket = [this](const std::string& token) {
    return static_cast<float>(
        tc3farmhash::Fingerprint64(token.data(), token.size()) %
        options_->num_buckets);
  };

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // ElementsAre checks the size too, so no element is read out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket("c")),
                          FloatEq(options_->end_token_id),
                          FloatEq(options_->start_token_id),
                          FloatEq(bucket("d")), FloatEq(bucket("e")),
                          FloatEq(bucket("f")),
                          FloatEq(options_->end_token_id)));
}
1655 
// With default options, the multi-task model returns every head's output.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  // Fail cleanly instead of dereferencing null if the model did not load.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // 8 binary classification results + 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(11));
}
1667 
// Threshold value used to switch a classification head off (presumably above
// any confidence the heads produce, assuming scores lie in [0, 1]).
constexpr float kDisableThresholdVal = 2.0f;

// Model parameter names of the multi-task model's classification thresholds.
constexpr char kSpamThreshold[] = "spam_confidence_threshold";
constexpr char kLocationThreshold[] = "location_confidence_threshold";
constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";

// Other tunable model parameter names.
constexpr char kDiversificationParm[] = "diversification_distance_threshold";
constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
1680 
GetOptionsToDisableAllClassification()1681 ActionSuggestionOptions GetOptionsToDisableAllClassification() {
1682   ActionSuggestionOptions options;
1683   // Disable all classification heads.
1684   options.model_parameters.insert(
1685       {kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1686   options.model_parameters.insert(
1687       {kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1688   options.model_parameters.insert(
1689       {kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1690   options.model_parameters.insert(
1691       {kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1692   options.model_parameters.insert(
1693       {kRestaurantsThreshold,
1694        libtextclassifier3::Variant(kDisableThresholdVal)});
1695   options.model_parameters.insert(
1696       {kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1697   options.model_parameters.insert(
1698       {kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1699   options.model_parameters.insert(
1700       {kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1701   return options;
1702 }
1703 
// With every classification head disabled, only smart replies remain.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
  std::unique_ptr<ActionsSuggestions> multi_task_model =
      LoadMultiTaskTestModel();
  const ActionSuggestionOptions smart_reply_only_options =
      GetOptionsToDisableAllClassification();

  const ActionsSuggestionsResponse response = multi_task_model->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, smart_reply_only_options);

  EXPECT_THAT(response.actions,
              ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
                          IsSmartReply("I'm home")));
  // 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(3));
}
1720 
// Number of entries in the synthetic user profile.
constexpr int kUserProfileSize = 1000;
// Model parameter names of the user profile inputs.
constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
1724 
GetOptionsForSmartReplyP13nModel()1725 ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
1726   ActionSuggestionOptions options;
1727   const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
1728   const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
1729   options.model_parameters.insert(
1730       {kUserProfileTokenIndex,
1731        libtextclassifier3::Variant(user_profile_token_indexes)});
1732   options.model_parameters.insert(
1733       {kUserProfileTokenWeight,
1734        libtextclassifier3::Variant(user_profile_token_weights)});
1735   return options;
1736 }
1737 
// The personalized smart reply model returns smart replies when given a user
// profile through the model parameters.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
  std::unique_ptr<ActionsSuggestions> p13n_model =
      LoadMultiTaskSrP13nTestModel();
  const ActionSuggestionOptions profile_options =
      GetOptionsForSmartReplyP13nModel();

  const ActionsSuggestionsResponse response = p13n_model->SuggestActions(
      {{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, profile_options);

  // 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(3));
}
1750 
// With only the location head enabled and a diversification threshold set,
// the response mixes a location share with diversified smart replies.
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
  std::unique_ptr<ActionsSuggestions> multi_task_model =
      LoadMultiTaskTestModel();

  // Start from all heads disabled, then re-enable location and add the
  // diversification distance threshold.
  ActionSuggestionOptions diversified_options =
      GetOptionsToDisableAllClassification();
  diversified_options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  diversified_options.model_parameters.insert(
      {kDiversificationParm, libtextclassifier3::Variant(0.5f)});

  const ActionsSuggestionsResponse response = multi_task_model->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, diversified_options);

  EXPECT_THAT(
      response.actions,
      ElementsAre(IsActionOfType("LOCATION_SHARE"), IsSmartReply("Here"),
                  IsSmartReply("Yes"), IsSmartReply("��")));
  // 1 location share + 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(4));
}
1772 
// With the empirical probability factor set, smart replies are boosted and
// interleaved with the location share and reminder actions.
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  // Re-enable the location head.
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  // The reminder head always triggers because its threshold is zero.
  options.model_parameters[kReminderThreshold] =
      libtextclassifier3::Variant(0.0f);
  options.model_parameters.insert(
      {kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
                  IsSmartReply("Yes"),
                  /*Different emoji than previous test*/ IsSmartReply("��"),
                  IsActionOfType("REMINDER_INTENT")));
  // 3 smart replies + 1 location share + 1 reminder.
  EXPECT_THAT(response.actions, SizeIs(5));
}
1799 
// The smart-reply/emoji model returns emoji and text replies as text_reply.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal so that the indexed checks below never read out of bounds.
  ASSERT_THAT(response.actions, SizeIs(5));
  EXPECT_EQ(response.actions[0].response_text, "��");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "��");
  EXPECT_EQ(response.actions[1].type, "text_reply");
  EXPECT_EQ(response.actions[2].response_text, "Yes");
  EXPECT_EQ(response.actions[2].type, "text_reply");
}
1819 
// Checks the replies the smart-reply/emoji model returns for
// "a pleasure chatting" (3 text_reply actions).
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "a pleasure chatting",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal so that the indexed checks below never read out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].response_text, "��");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "��");
  EXPECT_EQ(response.actions[1].type, "text_reply");
  EXPECT_EQ(response.actions[2].response_text, "Okay");
  EXPECT_EQ(response.actions[2].type, "text_reply");
}
1839 
// The concept-aware emoji model answers "i am tired" with one of the sigh
// emojis as its top emoji_reply.
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiConceptModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "i am tired",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  const std::vector<std::string> sigh_emojis = {"��", "��"};

  // Guard actions[0]; the original indexed without any size check.
  ASSERT_THAT(response.actions, testing::Not(IsEmpty()));
  // gMock matcher instead of std::find: no reliance on a transitive
  // <algorithm> include, and a readable failure message.
  EXPECT_THAT(sigh_emojis,
              testing::Contains(response.actions[0].response_text));
  EXPECT_EQ(response.actions[0].type, "emoji_reply");
}
1858 
// The Live Relay TFLite model suggests conversational text replies to "Hi".
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kLiveRelayTFLiteModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hi",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal so that the indexed checks below never read out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].response_text, "Hi how are you doing");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "Hi whats up");
  EXPECT_EQ(response.actions[1].type, "text_reply");
}
1875 
// A message flagged as sensitive yields no actions and is reported as
// sensitive rather than as filtered for low confidence.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
  std::unique_ptr<ActionsSuggestions> sensitive_model =
      LoadTestModel(kSensitiveTFliteModelFileName);

  const ActionsSuggestionsResponse response = sensitive_model->SuggestActions(
      {{{/*user_id=*/1, "I want to kill myself",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{},
         /*locales=*/"en"}}});

  EXPECT_THAT(response.actions, IsEmpty());
  EXPECT_TRUE(response.is_sensitive);
  EXPECT_FALSE(response.output_filtered_low_confidence);
}
1890 
1891 }  // namespace
1892 }  // namespace libtextclassifier3
1893