1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/actions-suggestions.h"
18
19 #include <fstream>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23
24 #include "actions/actions_model_generated.h"
25 #include "actions/test-utils.h"
26 #include "actions/zlib-utils.h"
27 #include "annotator/collections.h"
28 #include "annotator/types.h"
29 #include "utils/flatbuffers/flatbuffers.h"
30 #include "utils/flatbuffers/flatbuffers_generated.h"
31 #include "utils/flatbuffers/mutable.h"
32 #include "utils/grammar/utils/locale-shard-map.h"
33 #include "utils/grammar/utils/rules.h"
34 #include "utils/hash/farmhash.h"
35 #include "utils/jvm-test-utils.h"
36 #include "utils/test-data-test-utils.h"
37 #include "gmock/gmock.h"
38 #include "gtest/gtest.h"
39 #include "flatbuffers/flatbuffers.h"
40 #include "flatbuffers/reflection.h"
41
42 namespace libtextclassifier3 {
43 namespace {
44
45 using ::testing::ElementsAre;
46 using ::testing::FloatEq;
47 using ::testing::IsEmpty;
48 using ::testing::NotNull;
49 using ::testing::SizeIs;
50
// Test model file names. Each is relative to the test data directory returned
// by GetModelPath().

// Base actions model and its grammar variant.
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kModelGrammarFileName[] =
    "actions_suggestions_grammar_test.model";

// Multi-task model variants.
constexpr char kMultiTaskTF2TestModelFileName[] =
    "actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
    "actions_suggestions_test.multi_task_9heads.model";
constexpr char kMultiTaskSrP13nModelFileName[] =
    "actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kMultiTaskSrEmojiConceptModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji_concept.model";

// Other specialised models.
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";
constexpr char kSensitiveTFliteModelFileName[] =
    "actions_suggestions_test.sensitive_tflite.model";
constexpr char kLiveRelayTFLiteModelFileName[] =
    "actions_suggestions_test.live_relay.model";
70
// Reads the entire file at `file_name` and returns its raw bytes.
//
// The file is opened in binary mode. The test models are serialized
// flatbuffers, and text mode may translate line endings on some platforms,
// which would corrupt the buffer. If the file cannot be opened, the result is
// an empty string.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
75
GetModelPath()76 std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
77
78 class ActionsSuggestionsTest : public testing::Test {
79 protected:
ActionsSuggestionsTest()80 explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
LoadTestModel(const std::string model_file_name)81 std::unique_ptr<ActionsSuggestions> LoadTestModel(
82 const std::string model_file_name) {
83 return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
84 unilib_.get());
85 }
LoadHashGramTestModel()86 std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
87 return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
88 unilib_.get());
89 }
LoadMultiTaskTestModel()90 std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
91 return ActionsSuggestions::FromPath(
92 GetModelPath() + kMultiTaskModelFileName, unilib_.get());
93 }
94
LoadMultiTaskSrP13nTestModel()95 std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
96 return ActionsSuggestions::FromPath(
97 GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
98 }
99 std::unique_ptr<UniLib> unilib_;
100 };
101
// The default test model loads successfully.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  const std::unique_ptr<ActionsSuggestions> loaded_model =
      LoadTestModel(kModelFileName);
  EXPECT_THAT(loaded_model, NotNull());
}
105
// A message ending in a truncated UTF-8 sequence ("\xf0\x9f" is only the
// first two bytes of a four-byte code point) yields no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?\xf0\x9f",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
117
// A message containing invalid UTF-8 ("\xed\xa0\x80" is the encoding of the
// surrogate U+D800, which UTF-8 forbids) yields no actions.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
131
// A simple question produces a share_location action plus two smart replies.
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // SizeIs avoids a signed/unsigned comparison of size() with an int literal.
  EXPECT_THAT(response.actions,
              SizeIs(3) /* share_location + 2 smart replies*/);
}
142
// A message tagged with a locale the model does not handle ("zz") yields no
// actions.
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
153
// An "address" annotation on "home" produces a view_map action with the
// annotation's score as the top suggestion.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
171
// A custom annotation mapping with an `entity_field` copies the annotated
// text ("home") into the `location` field of the action's entity data.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_THAT(actions_model, NotNull());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the text from the address
  // annotation. Guard against an empty buffer and a missing field, both of
  // which would otherwise be dereferenced blindly.
  const auto& entity_data = response.actions.front().serialized_entity_data;
  ASSERT_FALSE(entity_data.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(entity_data.data()));
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_THAT(location, NotNull());
  EXPECT_EQ(location->str(), "home");
}
224
// Like the entity-data test, but the mapping also applies UPPERCASE
// normalization, so the `location` field holds "HOME".
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsFromAnnotationsWithNormalization) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_THAT(actions_model, NotNull());
  SetTestEntityDataSchema(actions_model.get());

  // Set custom actions from annotations config.
  actions_model->annotation_actions_spec->annotation_mapping.clear();
  actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
      new AnnotationActionsSpec_::AnnotationMappingT);
  AnnotationActionsSpec_::AnnotationMappingT* mapping =
      actions_model->annotation_actions_spec->annotation_mapping.back().get();
  mapping->annotation_collection = "address";
  mapping->action.reset(new ActionSuggestionSpecT);
  mapping->action->type = "save_location";
  mapping->action->score = 1.0;
  mapping->action->priority_score = 2.0;
  mapping->entity_field.reset(new FlatbufferFieldPathT);
  mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
  mapping->entity_field->field.back()->field_name = "location";
  mapping->normalization_options.reset(new NormalizationOptionsT);
  mapping->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "save_location");
  EXPECT_EQ(response.actions.front().score, 1.0);

  // Check that the `location` entity field holds the normalized text of the
  // annotation. Guard against an empty buffer and a missing field, both of
  // which would otherwise be dereferenced blindly.
  const auto& entity_data = response.actions.front().serialized_entity_data;
  ASSERT_FALSE(entity_data.empty());
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(entity_data.data()));
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_THAT(location, NotNull());
  EXPECT_EQ(location->str(), "HOME");
}
281
// Two flight annotations on the same text ("LX38") are deduplicated. Only the
// higher-scored one (3.0) is reported, followed by the email action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or [email protected].",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
311
// With `deduplicate_annotations` disabled, both flight annotations produce a
// track_flight action (scores 3.0 and 2.5), followed by the email action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_THAT(actions_model, NotNull());
  // Disable deduplication.
  actions_model->annotation_actions_spec->deduplicate_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly rather than crash on a null dereference if loading fails.
  ASSERT_THAT(actions_suggestions, NotNull());
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or [email protected].",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "track_flight");
  EXPECT_EQ(response.actions[1].score, 2.5);
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[2].score, 2.0);
}
354
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)355 ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
356 const std::function<void(ActionsModelT*)>& set_config_fn,
357 const UniLib* unilib = nullptr) {
358 const std::string actions_model_string =
359 ReadFile(GetModelPath() + kModelFileName);
360 std::unique_ptr<ActionsModelT> actions_model =
361 UnPackActionsModel(actions_model_string.c_str());
362
363 // Set custom config.
364 set_config_fn(actions_model.get());
365
366 // Disable smart reply for easier testing.
367 actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
368
369 flatbuffers::FlatBufferBuilder builder;
370 FinishActionsModelBuffer(builder,
371 ActionsModel::Pack(builder, actions_model.get()));
372 std::unique_ptr<ActionsSuggestions> actions_suggestions =
373 ActionsSuggestions::FromUnownedBuffer(
374 reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
375 builder.GetSize(), unilib);
376
377 AnnotatedSpan flight_annotation;
378 flight_annotation.span = {15, 19};
379 flight_annotation.classification = {ClassificationResult("flight", 2.0)};
380 AnnotatedSpan email_annotation;
381 email_annotation.span = {0, 16};
382 email_annotation.classification = {ClassificationResult("email", 1.0)};
383
384 return actions_suggestions->SuggestActions(
385 {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
386 "[email protected]",
387 /*reference_time_ms_utc=*/0,
388 /*reference_timezone=*/"Europe/Zurich",
389 /*annotations=*/
390 {email_annotation},
391 /*locales=*/"en"},
392 {/*user_id=*/2,
393 "[email protected]",
394 /*reference_time_ms_utc=*/0,
395 /*reference_timezone=*/"Europe/Zurich",
396 /*annotations=*/
397 {email_annotation},
398 /*locales=*/"en"},
399 {/*user_id=*/1,
400 "[email protected]",
401 /*reference_time_ms_utc=*/0,
402 /*reference_timezone=*/"Europe/Zurich",
403 /*annotations=*/
404 {email_annotation},
405 /*locales=*/"en"},
406 {/*user_id=*/1,
407 "I am on flight LX38.",
408 /*reference_time_ms_utc=*/0,
409 /*reference_timezone=*/"Europe/Zurich",
410 /*annotations=*/
411 {flight_annotation},
412 /*locales=*/"en"}}});
413 }
414
// With a history of one message from anyone and from the last person, only
// the annotation in the final message (the flight) produces an action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so the indexing below can never run out of bounds.
  ASSERT_THAT(response.actions, SizeIs(1));
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
429
// Allowing up to three messages from the last person picks up that person's
// earlier email annotation in addition to the flight.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so the indexing below can never run out of bounds.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
445
// A history of two messages from anyone yields the flight plus one email
// action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so the indexing below can never run out of bounds.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
461
// A history of three messages from anyone yields the flight plus two email
// actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so the indexing below can never run out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
479
// Even with a history of five messages, excluding local-user messages still
// caps the result at three actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so the indexing below can never run out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
497
// Including local-user messages (and not stopping at the last sent message)
// adds the local user's email annotation, giving four actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // ASSERT (not EXPECT) so the indexing below can never run out of bounds.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
516
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")517 void TestSuggestActionsWithThreshold(
518 const std::function<void(ActionsModelT*)>& set_value_fn,
519 const UniLib* unilib = nullptr, const int expected_size = 0,
520 const std::string& preconditions_overwrite = "") {
521 const std::string actions_model_string =
522 ReadFile(GetModelPath() + kModelFileName);
523 std::unique_ptr<ActionsModelT> actions_model =
524 UnPackActionsModel(actions_model_string.c_str());
525 set_value_fn(actions_model.get());
526 flatbuffers::FlatBufferBuilder builder;
527 FinishActionsModelBuffer(builder,
528 ActionsModel::Pack(builder, actions_model.get()));
529 std::unique_ptr<ActionsSuggestions> actions_suggestions =
530 ActionsSuggestions::FromUnownedBuffer(
531 reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
532 builder.GetSize(), unilib, preconditions_overwrite);
533 ASSERT_TRUE(actions_suggestions);
534 const ActionsSuggestionsResponse response =
535 actions_suggestions->SuggestActions(
536 {{{/*user_id=*/1, "I have the low-ground. Where are you?",
537 /*reference_time_ms_utc=*/0,
538 /*reference_timezone=*/"Europe/Zurich",
539 /*annotations=*/{}, /*locales=*/"en"}}});
540 EXPECT_LE(response.actions.size(), expected_size);
541 }
542
// Raising the smart-reply triggering threshold to 1.0 suppresses smart
// replies, leaving at most one (non-smart-reply) action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* model) {
    model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, unilib_.get(),
                                  /*expected_size=*/1);
}
552
// A minimum reply score of 1.0 filters out the smart replies, leaving at most
// one (non-smart-reply) action.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
  const auto raise_min_reply_score = [](ActionsModelT* model) {
    model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_min_reply_score, unilib_.get(),
                                  /*expected_size=*/1);
}
562
// A zero sensitive-topic ceiling still allows up to four actions, because the
// test model produces no sensitivity prediction.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
  const auto zero_sensitive_topic_score = [](ActionsModelT* model) {
    model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(zero_sensitive_topic_score, unilib_.get(),
                                  /*expected_size=*/4);
}
571
// With a maximum input length of 0, no actions are expected (the helper's
// default upper bound of 0 applies).
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
  const auto zero_max_input_length = [](ActionsModelT* model) {
    model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(zero_max_input_length, unilib_.get());
}
579
// A minimum input length of 100 is longer than the test message, so no
// actions are expected (the helper's default upper bound of 0 applies).
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
  const auto large_min_input_length = [](ActionsModelT* model) {
    model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(large_min_input_length, unilib_.get());
}
587
// The model is left unchanged. A serialized TriggeringPreconditions with
// max_input_length = 0 is passed as the preconditions overwrite, and no
// actions are expected.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
  TriggeringPreconditionsT overwrite;
  overwrite.max_input_length = 0;

  flatbuffers::FlatBufferBuilder overwrite_builder;
  overwrite_builder.Finish(
      TriggeringPreconditions::Pack(overwrite_builder, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(overwrite_builder.GetBufferPointer()),
      overwrite_builder.GetSize());

  const auto keep_model_untouched = [](ActionsModelT*) {};
  TestSuggestActionsWithThreshold(keep_model_untouched, unilib_.get(),
                                  /*expected_size=*/0, serialized_overwrite);
}
601
602 #ifdef TC3_UNILIB_ICU
// A low-confidence regex rule matching "low-ground" (which appears in the
// helper's test message), combined with suppress_on_low_confidence_input,
// leaves no actions.
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
  const auto add_low_confidence_rule = [](ActionsModelT* model) {
    model->preconditions->suppress_on_low_confidence_input = true;
    model->low_confidence_rules.reset(new RulesModelT);
    model->low_confidence_rules->regex_rule.emplace_back(
        new RulesModel_::RegexRuleT);
    RulesModel_::RegexRuleT* low_confidence_rule =
        model->low_confidence_rules->regex_rule.back().get();
    low_confidence_rule->pattern = "low-ground";
  };
  TestSuggestActionsWithThreshold(add_low_confidence_rule, unilib_.get());
}
615
// A custom rule answers "hello there" with two text replies. A low-confidence
// rule with an input pattern ("hello") and an output pattern
// ("(?i:desaster)") is then expected to remove "General Desaster!", so that
// "General Kenobi!" is the first suggestion.
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Custom triggering rule with two text replies, added in this order.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  const auto add_text_reply = [rule](const char* reply_text) {
    std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
        new RulesModel_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = reply_text;
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  };
  add_text_reply("General Desaster!");
  add_text_reply("General Kenobi!");

  // Input-output low confidence rule.
  actions_model->preconditions->suppress_on_low_confidence_input = true;
  actions_model->low_confidence_rules.reset(new RulesModelT);
  actions_model->low_confidence_rules->regex_rule.emplace_back(
      new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* low_confidence_rule =
      actions_model->low_confidence_rules->regex_rule.back().get();
  low_confidence_rule->pattern = "hello";
  low_confidence_rule->output_pattern = "(?i:desaster)";

  flatbuffers::FlatBufferBuilder model_builder;
  FinishActionsModelBuffer(
      model_builder, ActionsModel::Pack(model_builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(model_builder.GetBufferPointer()),
          model_builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
673
// Same scenario as the input-output low-confidence test. Here the model's own
// low-confidence rules are removed, and the input/output rule is supplied
// instead through a serialized TriggeringPreconditions overwrite at load time.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsLowConfidenceInputOutputOverwrite) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  actions_model->low_confidence_rules.reset();

  // Custom triggering rule with two text replies, added in this order.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  const auto add_text_reply = [rule](const char* reply_text) {
    std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
        new RulesModel_::RuleActionSpecT);
    rule_action->action.reset(new ActionSuggestionSpecT);
    rule_action->action->type = "text_reply";
    rule_action->action->response_text = reply_text;
    rule_action->action->score = 1.0f;
    rule_action->action->priority_score = 1.0f;
    rule->actions.push_back(std::move(rule_action));
  };
  add_text_reply("General Desaster!");
  add_text_reply("General Kenobi!");

  // Supply the low-confidence rule via the preconditions overwrite instead of
  // the model.
  actions_model->preconditions->low_confidence_rules.reset();
  TriggeringPreconditionsT preconditions;
  preconditions.suppress_on_low_confidence_input = true;
  preconditions.low_confidence_rules.reset(new RulesModelT);
  preconditions.low_confidence_rules->regex_rule.emplace_back(
      new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* low_confidence_rule =
      preconditions.low_confidence_rules->regex_rule.back().get();
  low_confidence_rule->pattern = "hello";
  low_confidence_rule->output_pattern = "(?i:desaster)";
  flatbuffers::FlatBufferBuilder preconditions_builder;
  preconditions_builder.Finish(
      TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
  const std::string serialized_preconditions(
      reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
      preconditions_builder.GetSize());

  flatbuffers::FlatBufferBuilder model_builder;
  FinishActionsModelBuffer(
      model_builder, ActionsModel::Pack(model_builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(model_builder.GetBufferPointer()),
          model_builder.GetSize(), unilib_.get(), serialized_preconditions);

  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
743 #endif
744
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(actions_model);

  // Don't test if no sensitivity score is produced: the suppression checked
  // below depends on the model emitting a sensitive-topic score.
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  // Threshold 0.0 with suppression enabled; the test expects every action to
  // be suppressed for this conversation.
  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
778
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Allow a larger conversation context.
  actions_model->max_conversation_history_length = 10;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_TRUE(actions_suggestions);

  // Address annotation on "home" in the second (remote user's) message.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
             /*reference_time_ms_utc=*/10000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"},
            {/*user_id=*/1, "good! are you at home?",
             /*reference_time_ms_utc=*/15000,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 1.0);
}
814
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskTF2TestModelFileName);
  ASSERT_TRUE(actions_suggestions);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hello how are you",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal check: the expectations below index up to actions[3], which would
  // be out of bounds if fewer actions were returned.
  ASSERT_EQ(response.actions.size(), 4);
  EXPECT_EQ(response.actions[0].response_text, "Okay");
  EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
  EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
}
830
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelGrammarFileName);
  ASSERT_TRUE(actions_suggestions);
  // A zero-score "phone" annotation is passed in. The assertions below expect
  // the resulting action's annotation span to be {15, 20} ("*1234"), not the
  // passed-in {11, 15}.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("phone", 0.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Contact us at: *1234",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  const ActionSuggestion& action = response.actions.front();
  EXPECT_EQ(action.type, "call_phone");
  EXPECT_EQ(action.score, 0.0);
  EXPECT_EQ(action.priority_score, 0.0);
  // Fatal check: front() on an empty annotation list would be undefined.
  ASSERT_EQ(action.annotations.size(), 1);
  EXPECT_EQ(action.annotations.front().span.span.first, 15);
  EXPECT_EQ(action.annotations.front().span.span.second, 20);
}
852
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  // Flight annotation covering "LX38".
  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // Fatal check: annotations[0] is read right below.
  ASSERT_THAT(response.actions[0].annotations, SizeIs(1));
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
876
877 #ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // Rule: a case-insensitive "hello there" yields a fixed text reply.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\s(there))$";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data: group 0 -> "greeting",
  // group 1 -> "location".
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
      rule->actions.back()->capturing_group.back().get();
  location_group->group_id = 1;
  location_group->entity_field.reset(new FlatbufferFieldPathT);
  location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  location_group->entity_field->field.back()->field_name = "location";

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  // Use meta data to generate custom serialized entity data.
  MutableFlatbufferBuilder entity_data_builder(
      flatbuffers::GetRoot<reflection::Schema>(
          actions_model->actions_entity_data_schema.data()));
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder.NewRoot();
  entity_data->Set("person", "Kenobi");
  action->serialized_entity_data = entity_data->Serialize();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal check: actions[0] is read below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data: the captured groups plus the static "person" value.
  ASSERT_FALSE(response.actions[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  ASSERT_THAT(entity, NotNull());
  // GetPointer returns null for absent fields; check before calling str().
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_THAT(greeting, NotNull());
  ASSERT_THAT(location, NotNull());
  ASSERT_THAT(person, NotNull());
  EXPECT_EQ(greeting->str(), "hello there");
  EXPECT_EQ(location->str(), "there");
  EXPECT_EQ(person->str(), "Kenobi");
}
954
TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // Rule: a case-insensitive "hello there" yields a fixed text reply.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:hello\\sthere)$";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->type = "text_reply";
  action->response_text = "General Kenobi!";
  action->score = 1.0f;
  action->priority_score = 1.0f;

  // Set capturing groups for entity data: group 0 -> "greeting", normalized
  // by dropping whitespace and uppercasing.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
      rule->actions.back()->capturing_group.back().get();
  greeting_group->group_id = 0;
  greeting_group->entity_field.reset(new FlatbufferFieldPathT);
  greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
  greeting_group->entity_field->field.back()->field_name = "greeting";
  greeting_group->normalization_options.reset(new NormalizationOptionsT);
  greeting_group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  // Set test entity data schema.
  SetTestEntityDataSchema(actions_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal check: actions[0] is read below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");

  // Check entity data.
  ASSERT_FALSE(response.actions[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          response.actions[0].serialized_entity_data.data()));
  ASSERT_THAT(entity, NotNull());
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_THAT(greeting, NotNull());
  EXPECT_EQ(greeting->str(), "HELLOTHERE");
}
1014
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // Rule: "reply <stop|quit|end> to/for" offers the captured keyword as reply.
  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);

  // Capturing group 1 becomes a lowercased text reply.
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->text_reply.reset(new ActionSuggestionSpecT);
  code_group->text_reply->score = 1.0f;
  code_group->text_reply->priority_score = 1.0f;
  code_group->normalization_options.reset(new NormalizationOptionsT);
  code_group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "visit test.com or reply STOP to cancel your subscription",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Fatal check: actions[0] is read below.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].response_text, "stop");
}
1059
TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  // The grammar rules are attached to the model's existing rules; fail cleanly
  // if the test model has none.
  ASSERT_TRUE(actions_model->rules);
  actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);

  // Set tokenizer options.
  RulesModel_::GrammarRulesT* action_grammar_rules =
      actions_model->rules->grammar_rules.get();
  action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
  action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
  action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
      false;

  // Setup test rules: a message consisting of "ventura" with an optional "!".
  action_grammar_rules->rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<knock>", {"<^>", "ventura", "!?", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/0);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules->rules.get());
  action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  RulesModel_::RuleActionSpecT* actions_spec =
      action_grammar_rules->actions.back().get();
  actions_spec->action.reset(new ActionSuggestionSpecT);
  actions_spec->action->response_text = "Yes, Satan?";
  actions_spec->action->priority_score = 1.0;
  actions_spec->action->score = 1.0;
  actions_spec->action->type = "text_reply";
  // The rule match triggers action 0 (the reply defined above).
  action_grammar_rules->rule_match.emplace_back(
      new RulesModel_::GrammarRules_::RuleMatchT);
  action_grammar_rules->rule_match.back()->action_id.push_back(0);

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Ventura!",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});

  EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
}
1118
1119 #if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
  // The annotator supplies the datetime annotation the grammar binds to.
  ASSERT_TRUE(annotator);
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  ASSERT_TRUE(actions_model->rules);
  actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);

  // Set tokenizer options.
  RulesModel_::GrammarRulesT* action_grammar_rules =
      actions_model->rules->grammar_rules.get();
  action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
  action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
  action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
      false;

  // Setup test rules: "it is at <time>", where <time> is bound to a "time"
  // annotation.
  action_grammar_rules->rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<event>", {"it", "is", "at", "<time>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/0);
  rules.BindAnnotation("<time>", "time");
  rules.AddAnnotation("datetime");
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules->rules.get());
  action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  RulesModel_::RuleActionSpecT* actions_spec =
      action_grammar_rules->actions.back().get();
  actions_spec->action.reset(new ActionSuggestionSpecT);
  actions_spec->action->priority_score = 1.0;
  actions_spec->action->score = 1.0;
  actions_spec->action->type = "create_event";
  action_grammar_rules->rule_match.emplace_back(
      new RulesModel_::GrammarRules_::RuleMatchT);
  action_grammar_rules->rule_match.back()->action_id.push_back(0);

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "it is at 10:30",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          annotator.get());

  EXPECT_THAT(response.actions, ElementsAre(IsActionOfType("create_event")));
}
1182 #endif
1183
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});

  // Check that the location sharing model triggered.
  EXPECT_THAT(response.actions,
              testing::Contains(testing::Field(
                  &ActionSuggestion::type,
                  ActionsSuggestionsTypes::ShareLocation())));
  const int num_actions = response.actions.size();

  // Add a custom rule producing the same location-sharing action; the result
  // must not grow, i.e. the duplicate is removed.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  actions_model->rules->regex_rule.back()->pattern =
      "^(?i:where are you[.?]?)$";
  actions_model->rules->regex_rule.back()->actions.emplace_back(
      new RulesModel_::RuleActionSpecT);
  actions_model->rules->regex_rule.back()->actions.back()->action.reset(
      new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action =
      actions_model->rules->regex_rule.back()->actions.back()->action.get();
  action->score = 1.0f;
  action->type = ActionsSuggestionsTypes::ShareLocation();

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, SizeIs(num_actions));
}
1236
TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  // Flight annotation covering "LX38".
  AnnotatedSpan annotation;
  annotation.span = {7, 11};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};
  ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});

  // Check that the flight tracking action is present and ranked first.
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "track_flight");

  // Add a custom rule whose action annotates the same span with a higher
  // priority; it is expected to take the first position instead.
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());
  ASSERT_TRUE(DecompressActionsModel(actions_model.get()));

  actions_model->rules.reset(new RulesModelT());
  actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
  RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
  rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
  rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  rule->actions.back()->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* action = rule->actions.back()->action.get();
  action->score = 1.0f;
  action->priority_score = 2.0f;
  action->type = "test_code";
  rule->actions.back()->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
      rule->actions.back()->capturing_group.back().get();
  code_group->group_id = 1;
  code_group->annotation_name = "code";
  code_group->annotation_type = "code";

  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
      reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get());
  ASSERT_TRUE(actions_suggestions);

  response = actions_suggestions->SuggestActions(
      {{{/*user_id=*/1, "I'm on LX38",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{annotation},
         /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions[0].type, "test_code");
}
1296 #endif
1297
TEST_F(ActionsSuggestionsTest, RanksActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  ASSERT_TRUE(actions_suggestions);
  // Two address annotations ("home" and "work") with different scores; the
  // higher-scored one is expected first.
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // Fatal check: actions[1] is read below.
  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}
1319
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // The visitor only reports whether a model was handed to it at all.
  const auto model_was_loaded = [](const ActionsModel* model) {
    return model != nullptr;
  };

  EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
                                      model_was_loaded));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb", model_was_loaded));
}
1336
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadHashGramTestModel();
  ASSERT_TRUE(actions_suggestions != nullptr);

  // Runs a single-message, un-annotated English conversation from user 1.
  const auto suggest = [&actions_suggestions](const std::string& text) {
    return actions_suggestions->SuggestActions(
        {{{/*user_id=*/1, text,
           /*reference_time_ms_utc=*/0,
           /*reference_timezone=*/"Europe/Zurich",
           /*annotations=*/{},
           /*locales=*/"en"}}});
  };

  const ActionsSuggestionsResponse greeting_response = suggest("hello");
  EXPECT_THAT(greeting_response.actions, testing::IsEmpty());

  const ActionsSuggestionsResponse location_response =
      suggest("where are you");
  EXPECT_THAT(
      location_response.actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));

  const ActionsSuggestionsResponse contact_response =
      suggest("do you know johns number");
  EXPECT_THAT(
      contact_response.actions,
      ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
}
1376
1377 // Test class to expose token embedding methods for testing.
1378 class TestingMessageEmbedder : private ActionsSuggestions {
1379 public:
1380 explicit TestingMessageEmbedder(const ActionsModel* model);
1381
1382 using ActionsSuggestions::EmbedAndFlattenTokens;
1383 using ActionsSuggestions::EmbedTokensPerMessage;
1384
1385 protected:
1386 // EmbeddingExecutor that always returns features based on
1387 // the id of the sparse features.
1388 class FakeEmbeddingExecutor : public EmbeddingExecutor {
1389 public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1390 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1391 const int dest_size) const override {
1392 TC3_CHECK_GE(dest_size, 1);
1393 EXPECT_EQ(sparse_features.size(), 1);
1394 dest[0] = sparse_features.data()[0];
1395 return true;
1396 }
1397 };
1398
1399 std::unique_ptr<UniLib> unilib_;
1400 };
1401
TestingMessageEmbedder(const ActionsModel * model)1402 TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
1403 : unilib_(CreateUniLibForTesting()) {
1404 model_ = model;
1405 const ActionsTokenFeatureProcessorOptions* options =
1406 model->feature_processor_options();
1407 feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
1408 embedding_executor_.reset(new FakeEmbeddingExecutor());
1409 EXPECT_TRUE(
1410 EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1411 EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1412 EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1413 token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1414 EXPECT_EQ(token_embedding_size_, 1);
1415 }
1416
1417 class EmbeddingTest : public testing::Test {
1418 protected:
EmbeddingTest()1419 explicit EmbeddingTest() {
1420 model_.feature_processor_options.reset(
1421 new ActionsTokenFeatureProcessorOptionsT);
1422 options_ = model_.feature_processor_options.get();
1423 options_->chargram_orders = {1};
1424 options_->num_buckets = 1000;
1425 options_->embedding_size = 1;
1426 options_->start_token_id = 0;
1427 options_->end_token_id = 1;
1428 options_->padding_token_id = 2;
1429 options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1430 }
1431
CreateTestingMessageEmbedder()1432 TestingMessageEmbedder CreateTestingMessageEmbedder() {
1433 flatbuffers::FlatBufferBuilder builder;
1434 FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1435 buffer_ = builder.Release();
1436 return TestingMessageEmbedder(
1437 flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1438 }
1439
1440 flatbuffers::DetachedBuffer buffer_;
1441 ActionsModelT model_;
1442 ActionsTokenFeatureProcessorOptionsT* options_;
1443 };
1444
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // ElementsAre checks the length before the elements, so a wrong-sized
  // result fails cleanly instead of being indexed out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket('a')), FloatEq(bucket('b')),
                          FloatEq(bucket('c'))));
}
1464
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // Three real tokens, then two padding tokens to reach the minimum of five.
  // ElementsAre also guards against out-of-bounds reads on a size mismatch.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket('a')), FloatEq(bucket('b')),
                          FloatEq(bucket('c')),
                          FloatEq(options_->padding_token_id),
                          FloatEq(options_->padding_token_id)));
}
1487
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // Only the last two tokens are kept; "a" is dropped from the front.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket('b')), FloatEq(bucket('c'))));
}
1506
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  ASSERT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Each message occupies max_num_tokens_per_message (3) slots, so the
  // two-token second message gets one padding token. ElementsAre also pins
  // the total size (6), which was previously never checked before indexing.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket('a')), FloatEq(bucket('b')),
                          FloatEq(bucket('c')), FloatEq(bucket('d')),
                          FloatEq(bucket('e')),
                          FloatEq(options_->padding_token_id)));
}
1531
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 5);
  // The message is wrapped in start/end tokens. ElementsAre checks the size
  // first, so a short result cannot be indexed out of bounds.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(options_->start_token_id),
                          FloatEq(bucket('a')), FloatEq(bucket('b')),
                          FloatEq(bucket('c')),
                          FloatEq(options_->end_token_id)));
}
1553
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // start + 3 tokens + end, then two padding tokens to reach the minimum of 7.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(options_->start_token_id),
                          FloatEq(bucket('a')), FloatEq(bucket('b')),
                          FloatEq(bucket('c')),
                          FloatEq(options_->end_token_id),
                          FloatEq(options_->padding_token_id),
                          FloatEq(options_->padding_token_id)));
}
1578
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 3);
  // Only the last three flattened entries survive: start and "a" are dropped.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket('b')), FloatEq(bucket('c')),
                          FloatEq(options_->end_token_id)));
}
1598
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 9);
  // Each message is wrapped in its own start/end pair, back to back.
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(options_->start_token_id),
                          FloatEq(bucket('a')), FloatEq(bucket('b')),
                          FloatEq(bucket('c')),
                          FloatEq(options_->end_token_id),
                          FloatEq(options_->start_token_id),
                          FloatEq(bucket('d')), FloatEq(bucket('e')),
                          FloatEq(options_->end_token_id)));
}
1627
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  // Expected embedding of a one-character token: Fingerprint64 of the
  // character modulo num_buckets, which FakeEmbeddingExecutor echoes.
  const auto bucket = [this](char c) -> float {
    return tc3farmhash::Fingerprint64(&c, 1) % options_->num_buckets;
  };
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  ASSERT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  EXPECT_EQ(total_token_count, 7);
  // The last 7 of the 10 flattened entries are kept: the first message loses
  // its start token, "a" and "b".
  EXPECT_THAT(embeddings,
              ElementsAre(FloatEq(bucket('c')),
                          FloatEq(options_->end_token_id),
                          FloatEq(options_->start_token_id),
                          FloatEq(bucket('d')), FloatEq(bucket('e')),
                          FloatEq(bucket('f')),
                          FloatEq(options_->end_token_id)));
}
1655
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskTestModel();
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  // With no parameter overrides: 8 binary classification actions plus
  // 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(11));
}
1667
// Threshold override used to switch a classification head off.
// NOTE(review): presumably effective because confidences never exceed 1.0.
constexpr float kDisableThresholdVal = 2.0f;

// Keys of model parameters that tests override through
// ActionSuggestionOptions::model_parameters.
constexpr char kSpamThreshold[] = "spam_confidence_threshold";
constexpr char kLocationThreshold[] = "location_confidence_threshold";
constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";
constexpr char kDiversificationParm[] = "diversification_distance_threshold";
constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
1680
GetOptionsToDisableAllClassification()1681 ActionSuggestionOptions GetOptionsToDisableAllClassification() {
1682 ActionSuggestionOptions options;
1683 // Disable all classification heads.
1684 options.model_parameters.insert(
1685 {kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1686 options.model_parameters.insert(
1687 {kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1688 options.model_parameters.insert(
1689 {kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1690 options.model_parameters.insert(
1691 {kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1692 options.model_parameters.insert(
1693 {kRestaurantsThreshold,
1694 libtextclassifier3::Variant(kDisableThresholdVal)});
1695 options.model_parameters.insert(
1696 {kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1697 options.model_parameters.insert(
1698 {kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1699 options.model_parameters.insert(
1700 {kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1701 return options;
1702 }
1703
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskTestModel();
  // Every classification threshold is overridden with kDisableThresholdVal,
  // so only smart replies are expected back.
  const ActionSuggestionOptions classification_off =
      GetOptionsToDisableAllClassification();
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, classification_off);
  EXPECT_THAT(response.actions,
              ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
                          IsSmartReply("I'm home")));
  EXPECT_THAT(response.actions, SizeIs(3));  // 3 smart replies
}
1720
// Number of entries in the synthetic user profile fed to the P13n model.
constexpr int kUserProfileSize = 1000;
// Model-parameter keys under which the user profile is passed.
constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
1724
GetOptionsForSmartReplyP13nModel()1725 ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
1726 ActionSuggestionOptions options;
1727 const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
1728 const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
1729 options.model_parameters.insert(
1730 {kUserProfileTokenIndex,
1731 libtextclassifier3::Variant(user_profile_token_indexes)});
1732 options.model_parameters.insert(
1733 {kUserProfileTokenWeight,
1734 libtextclassifier3::Variant(user_profile_token_weights)});
1735 return options;
1736 }
1737
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskSrP13nTestModel();
  // Pass a synthetic user profile through the model parameters.
  const ActionSuggestionOptions profile_options =
      GetOptionsForSmartReplyP13nModel();
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, profile_options);
  EXPECT_THAT(response.actions, SizeIs(3));  // 3 smart replies
}
1750
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskTestModel();
  // Start with every classification head off, re-enable only location, and
  // set the diversification distance threshold.
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  options.model_parameters.insert(
      {kDiversificationParm, libtextclassifier3::Variant(0.5f)});

  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, options);

  EXPECT_THAT(response.actions,
              ElementsAre(IsActionOfType("LOCATION_SHARE"),
                          IsSmartReply("Here"), IsSmartReply("Yes"),
                          IsSmartReply("")));
  // 1 location share + 3 smart replies.
  EXPECT_THAT(response.actions, SizeIs(4));
}
1772
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  // Start with every classification head off, then re-enable location.
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  // The reminder head always triggers because its threshold is zero.
  options.model_parameters[kReminderThreshold] =
      libtextclassifier3::Variant(0.0f);
  // Empirical probability factor (the "EmPro" boost in the test name).
  options.model_parameters.insert(
      {kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  // NOTE(review): the emoji literal below is empty in this copy of the file,
  // as is the one in the diversified test, although the comment says they
  // differ. Confirm the characters were not lost to an encoding problem.
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
                  IsSmartReply("Yes"),
                  /*Different emoji than previous test*/ IsSmartReply(""),
                  IsActionOfType("REMINDER_INTENT")));
  EXPECT_EQ(response.actions.size(),
            5 /*1 location share + 1 reminder + 3 smart replies*/);
}
1799
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal check: a short result must stop the test before the indexed
  // accesses below would read out of bounds.
  ASSERT_THAT(response.actions, SizeIs(5));
  // NOTE(review): the expected emoji literals are empty strings in this copy
  // of the file; confirm they were not lost to an encoding problem.
  EXPECT_EQ(response.actions[0].response_text, "");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "");
  EXPECT_EQ(response.actions[1].type, "text_reply");
  EXPECT_EQ(response.actions[2].response_text, "Yes");
  EXPECT_EQ(response.actions[2].type, "text_reply");
}
1819
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "a pleasure chatting",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal check: a short result must stop the test before the indexed
  // accesses below would read out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  // NOTE(review): the expected emoji literals are empty strings in this copy
  // of the file; confirm they were not lost to an encoding problem.
  EXPECT_EQ(response.actions[0].response_text, "");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "");
  EXPECT_EQ(response.actions[1].type, "text_reply");
  EXPECT_EQ(response.actions[2].response_text, "Okay");
  EXPECT_EQ(response.actions[2].type, "text_reply");
}
1839
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiConceptModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "i am tired",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Guard the actions[0] accesses below against an empty response.
  ASSERT_THAT(response.actions, testing::Not(IsEmpty()));

  // The top suggestion may be any one of the accepted "sigh" emojis.
  // NOTE(review): these literals are empty strings in this copy of the file;
  // confirm the emoji characters were not lost to an encoding problem.
  const std::vector<std::string> sigh_emojis = {"", ""};
  EXPECT_THAT(sigh_emojis,
              testing::Contains(response.actions[0].response_text));
  EXPECT_EQ(response.actions[0].type, "emoji_reply");
}
1858
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kLiveRelayTFLiteModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hi",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal check: a short result must stop the test before the indexed
  // accesses below would read out of bounds.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].response_text, "Hi how are you doing");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "Hi whats up");
  EXPECT_EQ(response.actions[1].type, "text_reply");
}
1875
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadTestModel(kSensitiveTFliteModelFileName);
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "I want to kill myself",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{},
         /*locales=*/"en"}}});
  // Expected outcome: no suggestions, the response flagged as sensitive, and
  // the low-confidence filter flag left unset.
  EXPECT_THAT(response.actions, IsEmpty());
  EXPECT_TRUE(response.is_sensitive);
  EXPECT_FALSE(response.output_filtered_low_confidence);
}
1890
1891 } // namespace
1892 } // namespace libtextclassifier3
1893