1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/grammar-actions.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <iostream>
20*993b0882SAndroid Build Coastguard Worker #include <memory>
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
23*993b0882SAndroid Build Coastguard Worker #include "actions/test-utils.h"
24*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/flatbuffers.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/mutable.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/rules_generated.h"
28*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/types.h"
29*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/utils/rules.h"
30*993b0882SAndroid Build Coastguard Worker #include "utils/jvm-test-utils.h"
31*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
32*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
33*993b0882SAndroid Build Coastguard Worker
34*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
35*993b0882SAndroid Build Coastguard Worker namespace {
36*993b0882SAndroid Build Coastguard Worker
37*993b0882SAndroid Build Coastguard Worker using ::testing::ElementsAre;
38*993b0882SAndroid Build Coastguard Worker using ::testing::IsEmpty;
39*993b0882SAndroid Build Coastguard Worker
40*993b0882SAndroid Build Coastguard Worker using ::libtextclassifier3::grammar::LocaleShardMap;
41*993b0882SAndroid Build Coastguard Worker
42*993b0882SAndroid Build Coastguard Worker class TestGrammarActions : public GrammarActions {
43*993b0882SAndroid Build Coastguard Worker public:
TestGrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const MutableFlatbufferBuilder * entity_data_builder=nullptr)44*993b0882SAndroid Build Coastguard Worker explicit TestGrammarActions(
45*993b0882SAndroid Build Coastguard Worker const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
46*993b0882SAndroid Build Coastguard Worker const MutableFlatbufferBuilder* entity_data_builder = nullptr)
47*993b0882SAndroid Build Coastguard Worker : GrammarActions(unilib, grammar_rules, entity_data_builder,
48*993b0882SAndroid Build Coastguard Worker
49*993b0882SAndroid Build Coastguard Worker /*smart_reply_action_type=*/"text_reply") {}
50*993b0882SAndroid Build Coastguard Worker };
51*993b0882SAndroid Build Coastguard Worker
52*993b0882SAndroid Build Coastguard Worker class GrammarActionsTest : public testing::Test {
53*993b0882SAndroid Build Coastguard Worker protected:
54*993b0882SAndroid Build Coastguard Worker struct AnnotationSpec {
55*993b0882SAndroid Build Coastguard Worker int group_id = 0;
56*993b0882SAndroid Build Coastguard Worker std::string annotation_name = "";
57*993b0882SAndroid Build Coastguard Worker bool use_annotation_match = false;
58*993b0882SAndroid Build Coastguard Worker };
59*993b0882SAndroid Build Coastguard Worker
GrammarActionsTest()60*993b0882SAndroid Build Coastguard Worker GrammarActionsTest()
61*993b0882SAndroid Build Coastguard Worker : unilib_(CreateUniLibForTesting()),
62*993b0882SAndroid Build Coastguard Worker serialized_entity_data_schema_(TestEntityDataSchema()),
63*993b0882SAndroid Build Coastguard Worker entity_data_builder_(new MutableFlatbufferBuilder(
64*993b0882SAndroid Build Coastguard Worker flatbuffers::GetRoot<reflection::Schema>(
65*993b0882SAndroid Build Coastguard Worker serialized_entity_data_schema_.data()))) {}
66*993b0882SAndroid Build Coastguard Worker
SetTokenizerOptions(RulesModel_::GrammarRulesT * action_grammar_rules) const67*993b0882SAndroid Build Coastguard Worker void SetTokenizerOptions(
68*993b0882SAndroid Build Coastguard Worker RulesModel_::GrammarRulesT* action_grammar_rules) const {
69*993b0882SAndroid Build Coastguard Worker action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
70*993b0882SAndroid Build Coastguard Worker action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
71*993b0882SAndroid Build Coastguard Worker action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
72*993b0882SAndroid Build Coastguard Worker false;
73*993b0882SAndroid Build Coastguard Worker }
74*993b0882SAndroid Build Coastguard Worker
AddActionSpec(const std::string & type,const std::string & response_text,const std::vector<AnnotationSpec> & annotations,RulesModel_::GrammarRulesT * action_grammar_rules) const75*993b0882SAndroid Build Coastguard Worker int AddActionSpec(const std::string& type, const std::string& response_text,
76*993b0882SAndroid Build Coastguard Worker const std::vector<AnnotationSpec>& annotations,
77*993b0882SAndroid Build Coastguard Worker RulesModel_::GrammarRulesT* action_grammar_rules) const {
78*993b0882SAndroid Build Coastguard Worker const int action_id = action_grammar_rules->actions.size();
79*993b0882SAndroid Build Coastguard Worker action_grammar_rules->actions.emplace_back(
80*993b0882SAndroid Build Coastguard Worker new RulesModel_::RuleActionSpecT);
81*993b0882SAndroid Build Coastguard Worker RulesModel_::RuleActionSpecT* actions_spec =
82*993b0882SAndroid Build Coastguard Worker action_grammar_rules->actions.back().get();
83*993b0882SAndroid Build Coastguard Worker actions_spec->action.reset(new ActionSuggestionSpecT);
84*993b0882SAndroid Build Coastguard Worker actions_spec->action->response_text = response_text;
85*993b0882SAndroid Build Coastguard Worker actions_spec->action->priority_score = 1.0;
86*993b0882SAndroid Build Coastguard Worker actions_spec->action->score = 1.0;
87*993b0882SAndroid Build Coastguard Worker actions_spec->action->type = type;
88*993b0882SAndroid Build Coastguard Worker // Create annotations for specified capturing groups.
89*993b0882SAndroid Build Coastguard Worker for (const AnnotationSpec& annotation : annotations) {
90*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.emplace_back(
91*993b0882SAndroid Build Coastguard Worker new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
92*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->group_id = annotation.group_id;
93*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->annotation_name =
94*993b0882SAndroid Build Coastguard Worker annotation.annotation_name;
95*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->annotation_type =
96*993b0882SAndroid Build Coastguard Worker annotation.annotation_name;
97*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->use_annotation_match =
98*993b0882SAndroid Build Coastguard Worker annotation.use_annotation_match;
99*993b0882SAndroid Build Coastguard Worker }
100*993b0882SAndroid Build Coastguard Worker
101*993b0882SAndroid Build Coastguard Worker return action_id;
102*993b0882SAndroid Build Coastguard Worker }
103*993b0882SAndroid Build Coastguard Worker
AddSmartReplySpec(const std::string & response_text,RulesModel_::GrammarRulesT * action_grammar_rules) const104*993b0882SAndroid Build Coastguard Worker int AddSmartReplySpec(
105*993b0882SAndroid Build Coastguard Worker const std::string& response_text,
106*993b0882SAndroid Build Coastguard Worker RulesModel_::GrammarRulesT* action_grammar_rules) const {
107*993b0882SAndroid Build Coastguard Worker return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
108*993b0882SAndroid Build Coastguard Worker }
109*993b0882SAndroid Build Coastguard Worker
AddCapturingMatchSmartReplySpec(const int match_id,RulesModel_::GrammarRulesT * action_grammar_rules) const110*993b0882SAndroid Build Coastguard Worker int AddCapturingMatchSmartReplySpec(
111*993b0882SAndroid Build Coastguard Worker const int match_id,
112*993b0882SAndroid Build Coastguard Worker RulesModel_::GrammarRulesT* action_grammar_rules) const {
113*993b0882SAndroid Build Coastguard Worker const int action_id = action_grammar_rules->actions.size();
114*993b0882SAndroid Build Coastguard Worker action_grammar_rules->actions.emplace_back(
115*993b0882SAndroid Build Coastguard Worker new RulesModel_::RuleActionSpecT);
116*993b0882SAndroid Build Coastguard Worker RulesModel_::RuleActionSpecT* actions_spec =
117*993b0882SAndroid Build Coastguard Worker action_grammar_rules->actions.back().get();
118*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.emplace_back(
119*993b0882SAndroid Build Coastguard Worker new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
120*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->group_id = match_id;
121*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->text_reply.reset(
122*993b0882SAndroid Build Coastguard Worker new ActionSuggestionSpecT);
123*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
124*993b0882SAndroid Build Coastguard Worker actions_spec->capturing_group.back()->text_reply->score = 1.0;
125*993b0882SAndroid Build Coastguard Worker return action_id;
126*993b0882SAndroid Build Coastguard Worker }
127*993b0882SAndroid Build Coastguard Worker
AddRuleMatch(const std::vector<int> & action_ids,RulesModel_::GrammarRulesT * action_grammar_rules) const128*993b0882SAndroid Build Coastguard Worker int AddRuleMatch(const std::vector<int>& action_ids,
129*993b0882SAndroid Build Coastguard Worker RulesModel_::GrammarRulesT* action_grammar_rules) const {
130*993b0882SAndroid Build Coastguard Worker const int rule_match_id = action_grammar_rules->rule_match.size();
131*993b0882SAndroid Build Coastguard Worker action_grammar_rules->rule_match.emplace_back(
132*993b0882SAndroid Build Coastguard Worker new RulesModel_::GrammarRules_::RuleMatchT);
133*993b0882SAndroid Build Coastguard Worker action_grammar_rules->rule_match.back()->action_id.insert(
134*993b0882SAndroid Build Coastguard Worker action_grammar_rules->rule_match.back()->action_id.end(),
135*993b0882SAndroid Build Coastguard Worker action_ids.begin(), action_ids.end());
136*993b0882SAndroid Build Coastguard Worker return rule_match_id;
137*993b0882SAndroid Build Coastguard Worker }
138*993b0882SAndroid Build Coastguard Worker
139*993b0882SAndroid Build Coastguard Worker std::unique_ptr<UniLib> unilib_;
140*993b0882SAndroid Build Coastguard Worker const std::string serialized_entity_data_schema_;
141*993b0882SAndroid Build Coastguard Worker std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
142*993b0882SAndroid Build Coastguard Worker };
143*993b0882SAndroid Build Coastguard Worker
// Checks that a rule bound to two fixed smart replies yields both replies, in
// the order their specs were registered.
TEST_F(GrammarActionsTest, ProducesSmartReplies) {
  // Grammar under test:
  //   ^knock knock.?$ -> "Who's there?", "Yes?"
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  // Register both replies (in this order) and bind them to one rule match.
  const int whos_there_id =
      AddSmartReplySpec("Who's there?", &action_grammar_rules);
  const int yes_id = AddSmartReplySpec("Yes?", &action_grammar_rules);
  const int knock_match_id =
      AddRuleMatch({whos_there_id, yes_id}, &action_grammar_rules);

  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
            /*callback=*/
            static_cast<grammar::CallbackId>(
                grammar::DefaultCallback::kRootRule),
            /*callback_param=*/knock_match_id);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> suggestions;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}}, &suggestions));
  EXPECT_THAT(suggestions,
              ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
}
174*993b0882SAndroid Build Coastguard Worker
// Checks that a capturing-group text_reply spec turns the matched <reply>
// span into a smart reply that keeps the casing found in the message.
TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
  // Grammar under test:
  //   <scripted_reply> ::= ^ text <captured_reply> to <command>
  //   <captured_reply> ::= <reply>   (captured with value 0)
  //   <command>        ::= unsubscribe | cancel | confirm | receive
  //   <reply>          ::= help | stop | cancel | yes
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  const int reply_match_id = AddRuleMatch(
      {AddCapturingMatchSmartReplySpec(/*match_id=*/0, &action_grammar_rules)},
      &action_grammar_rules);
  rules.Add("<scripted_reply>",
            {"<^>", "text", "<captured_reply>", "to", "<command>"},
            /*callback=*/
            static_cast<grammar::CallbackId>(
                grammar::DefaultCallback::kRootRule),
            /*callback_param=*/reply_match_id);

  for (const char* command : {"unsubscribe", "cancel", "confirm", "receive"}) {
    rules.Add("<command>", {command});
  }
  for (const char* reply : {"help", "stop", "cancel", "yes"}) {
    rules.Add("<reply>", {reply});
  }
  rules.AddValueMapping("<captured_reply>", {"<reply>"},
                        /*value=*/0);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  struct ReplyCase {
    const char* message;
    const char* expected_reply;
  };
  for (const ReplyCase& reply_case :
       {ReplyCase{"Text YES to confirm your subscription", "YES"},
        ReplyCase{"text Stop to cancel your order", "Stop"}}) {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/reply_case.message}}},
        &result));
    EXPECT_THAT(result, ElementsAre(IsSmartReply(reply_case.expected_reply)));
  }
}
232*993b0882SAndroid Build Coastguard Worker
// Checks that a capturing group attached to an action spec is reported as an
// annotation carrying the matched phone number and its codepoint span.
TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
  // Create test rules.
  // Rule: please dial <phone>
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "phone"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  // <phone> ::= +00 00 000 00 00, captured with value 0, which is the group
  // id used by the "phone" annotation above.
  rules.AddValueMapping("<phone>",
                        {"+", "<2_digits>", "<2_digits>", "<3_digits>",
                         "<2_digits>", "<2_digits>"},
                        /*value=*/0);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
      &result));

  // ASSERT (not EXPECT): result.front() below is undefined behavior if no
  // suggestion was produced, so stop the test here on mismatch.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
  EXPECT_THAT(result.front().annotations,
              ElementsAre(IsActionSuggestionAnnotation(
                  "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
}
272*993b0882SAndroid Build Coastguard Worker
// Checks that an "en" message only gets the default-shard reply, while an
// "fr-CH" message additionally gets the reply from the rule set that is
// tagged with language "fr".
TEST_F(GrammarActionsTest, HandlesLocales) {
  // Grammar under test:
  //   shard 0: ^knock knock.?$ -> "Who's there?"
  //   shard 1: <knock>         -> "Qui est là?"
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  // Action and rule match ids are allocated in this order.
  const int default_match_id = AddRuleMatch(
      {AddSmartReplySpec("Who's there?", &action_grammar_rules)},
      &action_grammar_rules);
  const int french_match_id = AddRuleMatch(
      {AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
      &action_grammar_rules);

  LocaleShardMap locale_shard_map =
      LocaleShardMap::CreateLocaleShardMap({"", "fr-CH"});
  grammar::Rules rules(locale_shard_map);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rules.Add("<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/default_match_id);
  rules.Add("<toc>", {"<knock>"},
            /*callback=*/root_rule_callback,
            /*callback_param=*/french_match_id,
            /*max_whitespace_gap=*/-1,
            /*case_sensitive=*/false,
            /*shard=*/1);
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  // Tag the last serialized rule set (presumably shard 1) with language "fr".
  auto& french_rules = action_grammar_rules.rules->rules.back();
  french_rules->locale.emplace_back(new LanguageTagT);
  french_rules->locale.back()->language = "fr";

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Runs the "knock knock" message with the given detected language tags.
  const auto suggest_for_language =
      [&grammar_actions](const std::string& language_tags) {
        std::vector<ActionSuggestion> suggestions;
        EXPECT_TRUE(grammar_actions.SuggestActions(
            {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
                            /*reference_time_ms_utc=*/0,
                            /*reference_timezone=*/"UTC", /*annotations=*/{},
                            /*detected_text_language_tags=*/language_tags}}},
            &suggestions));
        return suggestions;
      };

  // Default locale.
  EXPECT_THAT(suggest_for_language("en"),
              ElementsAre(IsSmartReply("Who's there?")));

  // French (Switzerland).
  EXPECT_THAT(suggest_for_language("fr-CH"),
              ElementsAre(IsSmartReply("Who's there?"),
                          IsSmartReply("Qui est là?")));
}
337*993b0882SAndroid Build Coastguard Worker
// Checks that a negative context assertion drops the flight match in
// "LX 38.38", while "LX38" and "aa 44" still produce track_flight actions with
// the flight annotated.
TEST_F(GrammarActionsTest, HandlesAssertions) {
  // Create test rules.
  // Rule: <flight> -> Track flight.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add("<carrier>", {"lx"});
  rules.Add("<carrier>", {"aa"});
  rules.Add("<flight_code>", {"<2_digits>"});
  rules.Add("<flight_code>", {"<3_digits>"});
  rules.Add("<flight_code>", {"<4_digits>"});

  // Capture flight code.
  rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
                        /*value=*/0);

  // Flight: carrier + flight code and check right context.
  rules.Add(
      "<track_flight>", {"<flight>", "<context_assertion>?"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
                                  /*annotations=*/{{0 /*value*/, "flight"}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));

  // Exclude matches like: LX 38.00 etc.
  rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                     /*negative=*/true);

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
      &result));

  // ASSERT (not EXPECT): result[0] and result[1] below would be read out of
  // bounds if fewer than two suggestions came back.
  ASSERT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
                                  IsActionOfType("track_flight")));
  EXPECT_THAT(result[0].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
                                                       CodepointSpan{0, 4})));
  EXPECT_THAT(result[1].annotations,
              ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
                                                       CodepointSpan{5, 10})));
}
392*993b0882SAndroid Build Coastguard Worker
TEST_F(GrammarActionsTest, SetsFixedEntityData) {
  // Verifies that static entity data configured on an action spec ends up in
  // the produced suggestion. Two static sources are configured: the
  // pre-serialized `serialized_entity_data` (person = "Kenobi") and the
  // object-API `entity_data` (text = "I have the high ground.").
  //
  // Rule: ^hello there$
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();
  action_grammar_rules.actions[spec_id]->action->entity_data.reset(
      new ActionsEntityDataT);
  action_grammar_rules.actions[spec_id]->action->entity_data->text =
      "I have the high ground.";

  rules.Add(
      "<greeting>", {"<^>", "hello", "there", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. This must be fatal (ASSERT): the checks
  // below index result[0], which is out of bounds if nothing was produced.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data. Reading a flatbuffer root from an empty buffer would
  // read garbage, so require some serialized data first.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // Fields are addressed by flatbuffers vtable offset in the fixture's
  // entity-data schema (not visible here). GetPointer returns nullptr for an
  // absent field, so check before dereferencing.
  const flatbuffers::String* text_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(text_field, nullptr);
  EXPECT_EQ(text_field->str(), "I have the high ground.");

  const flatbuffers::String* person_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(person_field, nullptr);
  EXPECT_EQ(person_field->str(), "Kenobi");
}
446*993b0882SAndroid Build Coastguard Worker
TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
  // Verifies that text matched by value-mapped nonterminals is written into
  // the entity-data fields named by the action's capturing groups, alongside
  // the action's static serialized entity data.
  //
  // Rules:
  //   <location>          ::= there | here
  //   <captured_location> ::= <location>               (capture id 1)
  //   <greeting>          ::= hello <captured_location> (capture id 0)
  //   <test>              ::= ^ <greeting> $           (root -> smart reply)
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply and static entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Kenobi");
  action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
      entity_data->Serialize();

  // Specify results for capturing matches. Each capturing group maps the
  // match of the value-mapped nonterminal with `group_id` to the entity-data
  // field `field_name`.
  const int greeting_match_id = 0;
  const int location_match_id = 1;
  auto add_capturing_group = [&action_grammar_rules, spec_id](
                                 const int group_id,
                                 const std::string& field_name) {
    auto& groups = action_grammar_rules.actions[spec_id]->capturing_group;
    groups.emplace_back(new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
    RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
        groups.back().get();
    group->group_id = group_id;
    group->entity_field.reset(new FlatbufferFieldPathT);
    group->entity_field->field.emplace_back(new FlatbufferFieldT);
    group->entity_field->field.back()->field_name = field_name;
  };
  add_capturing_group(greeting_match_id, "greeting");
  add_capturing_group(location_match_id, "location");

  rules.Add("<location>", {"there"});
  rules.Add("<location>", {"here"});
  rules.AddValueMapping("<captured_location>", {"<location>"},
                        /*value=*/location_match_id);
  rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
                        /*value=*/greeting_match_id);
  rules.Add(
      "<test>", {"<^>", "<greeting>", "<$>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. This must be fatal (ASSERT): result[0]
  // is indexed below.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // Fields are addressed by flatbuffers vtable offset in the fixture's
  // entity-data schema; absent fields come back as nullptr.
  const flatbuffers::String* greeting_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(greeting_field, nullptr);
  EXPECT_EQ(greeting_field->str(), "Hello there");

  const flatbuffers::String* location_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_NE(location_field, nullptr);
  EXPECT_EQ(location_field->str(), "there");

  const flatbuffers::String* person_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(person_field, nullptr);
  EXPECT_EQ(person_field->str(), "Kenobi");
}
529*993b0882SAndroid Build Coastguard Worker
TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
  // Verifies that fixed entity data attached to a capturing group (rather
  // than to the action itself) is emitted when that group matches.
  //
  // Rules:
  //   <greeting> ::= ^ hello there $   (capture id 0)
  //   <test>     ::= <greeting>        (root -> smart reply)
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // Create smart reply whose capturing group 0 carries fixed entity data.
  const int spec_id =
      AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
  action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
      new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
  RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
      action_grammar_rules.actions[spec_id]->capturing_group.back().get();
  group->group_id = 0;
  group->entity_data.reset(new ActionsEntityDataT);
  group->entity_data->text = "You are a bold one.";

  rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
                        /*value=*/0);
  rules.Add(
      "<test>", {"<greeting>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({spec_id}, &action_grammar_rules));
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  std::vector<ActionSuggestion> result;
  EXPECT_TRUE(grammar_actions.SuggestActions(
      {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));

  // Check the produced smart replies. This must be fatal (ASSERT): result[0]
  // is indexed below.
  ASSERT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));

  // Check entity data.
  ASSERT_FALSE(result[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result[0].serialized_entity_data.data()));

  // Field addressed by flatbuffers vtable offset in the fixture's entity-data
  // schema; GetPointer returns nullptr if the field is absent.
  const flatbuffers::String* text_field =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(text_field, nullptr);
  EXPECT_EQ(text_field->str(), "You are a bold one.");
}
580*993b0882SAndroid Build Coastguard Worker
TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
  // Verifies that a rule can consume an existing "phone" annotation via a
  // predefined nonterminal and attach that annotation to the produced action.
  //
  // Rule: please dial <phone>, where <phone> ::= <phone_annotation> and the
  // "phone" annotation is mapped onto <phone_annotation>.
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);
  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  rules.Add(
      "<call_phone>", {"please", "dial", "<phone>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
                                  /*annotations=*/
                                  {{0 /*value*/, "phone",
                                    /*use_annotation_match=*/true}},
                                  &action_grammar_rules)},
                   &action_grammar_rules));
  rules.AddValueMapping("<phone>", {"<phone_annotation>"},
                        /*value=*/0);

  grammar::Ir ir = rules.Finalize(
      /*predefined_nonterminals=*/{"<phone_annotation>"});
  ir.Serialize(/*include_debug_information=*/false,
               action_grammar_rules.rules.get());

  // Map "phone" annotation to "<phone_annotation>" nonterminal.
  action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
      new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
  action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
      ir.GetNonterminalForName("<phone_annotation>");

  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get());

  // Sanity check that no results are produced when no annotations are
  // provided: <phone> can only be matched through a "phone" annotation.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
        &result));
    EXPECT_THAT(result, IsEmpty());
  }

  // With a "phone" annotation over the number, the rule fires and the
  // annotation is carried over to the action. A fresh vector is used so the
  // outcome does not depend on whether SuggestActions clears its output.
  {
    std::vector<ActionSuggestion> result;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{
            {/*user_id=*/0,
             /*text=*/"Please dial +41 79 123 45 67",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"UTC",
             /*annotations=*/
             {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
        &result));
    // Fatal (ASSERT): front() below is undefined on an empty vector.
    ASSERT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
    EXPECT_THAT(result.front().annotations,
                ElementsAre(IsActionSuggestionAnnotation(
                    "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
  }
}
641*993b0882SAndroid Build Coastguard Worker
TEST_F(GrammarActionsTest, HandlesExclusions) {
  // Verifies exclusion rules: "do not forget to" followed by any two tokens
  // triggers a reminder, except when those two tokens form "be safe".
  RulesModel_::GrammarRulesT action_grammar_rules;
  SetTokenizerOptions(&action_grammar_rules);
  action_grammar_rules.rules.reset(new grammar::RulesSetT);

  LocaleShardMap locale_shard_map = LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);

  // <tokens_but_not_excluded> ::= <token> <token>, minus anything <excluded>.
  rules.Add("<excluded>", {"be", "safe"});
  rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
                         /*excluded_nonterminal=*/"<excluded>");

  const int reminder_spec = AddActionSpec("set_reminder", /*response_text=*/"",
                                          /*annotations=*/
                                          {}, &action_grammar_rules);
  rules.Add(
      "<set_reminder>",
      {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/
      AddRuleMatch({reminder_spec}, &action_grammar_rules));

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             action_grammar_rules.rules.get());
  OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
      PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
  TestGrammarActions grammar_actions(unilib_.get(), model.get(),
                                     entity_data_builder_.get());

  // Runs the grammar over a single-message conversation and returns the
  // suggestions it produced.
  auto suggest = [&grammar_actions](const char* text) {
    std::vector<ActionSuggestion> suggestions;
    EXPECT_TRUE(grammar_actions.SuggestActions(
        {/*messages=*/{{/*user_id=*/0, /*text=*/text}}}, &suggestions));
    return suggestions;
  };

  // Token pairs other than "be safe" are accepted.
  for (const char* accepted : {"do not forget to bring milk",
                               "do not forget to be there!",
                               "do not forget to buy safe or vault!"}) {
    EXPECT_THAT(suggest(accepted),
                ElementsAre(IsActionOfType("set_reminder")));
  }

  // The excluded phrase suppresses the action entirely.
  EXPECT_THAT(suggest("do not forget to be safe!"), IsEmpty());
}
706*993b0882SAndroid Build Coastguard Worker
707*993b0882SAndroid Build Coastguard Worker } // namespace
708*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
709