1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/semantics/composer.h"
18
19 #include "utils/flatbuffers/flatbuffers.h"
20 #include "utils/flatbuffers/reflection.h"
21 #include "utils/grammar/parsing/derivation.h"
22 #include "utils/grammar/parsing/parser.h"
23 #include "utils/grammar/rules_generated.h"
24 #include "utils/grammar/semantics/expression_generated.h"
25 #include "utils/grammar/testing/utils.h"
26 #include "utils/grammar/testing/value_generated.h"
27 #include "utils/grammar/types.h"
28 #include "utils/grammar/utils/rules.h"
29 #include "gmock/gmock.h"
30 #include "gtest/gtest.h"
31
32 namespace libtextclassifier3::grammar {
33 namespace {
34
35 using ::testing::ElementsAre;
36
37 class SemanticComposerTest : public GrammarTest {};
38
// Checks that each `<month>` alternative evaluates its attached constant
// semantic expression to the expected TestValue.
TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
  RulesSetT model;
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  const int test_value_type =
      TypeIdForName(semantic_values_schema_.get(),
                    "libtextclassifier3.grammar.TestValue")
          .value();

  // Adds the rule `<month> ::= <token>` whose semantic expression is a
  // constant TestValue holding `month_value`. The callback parameter is the
  // index the new expression will have in `model.semantic_expression`, so it
  // must be read before the expression is appended.
  auto add_month_constant = [&](const std::string& token, int month_value) {
    rules.Add("<month>", {token},
              static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
              /*callback_param=*/model.semantic_expression.size());
    TestValueT value;
    value.value = month_value;
    const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
    ConstValueExpressionT const_value;
    const_value.base_type = reflection::BaseType::Obj;
    const_value.type = test_value_type;
    const_value.value.assign(serialized_value.begin(), serialized_value.end());
    model.semantic_expression.emplace_back(new SemanticExpressionT);
    model.semantic_expression.back()->expression.Set(const_value);
  };
  add_month_constant("january", 1);
  add_month_constant("february", 2);

  constexpr int kMonth = 0;
  rules.Add("<month_rule>", {"<month>"},
            static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
  rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
  const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
  SemanticComposer composer(semantic_values_schema_.get());

  {
    const TextContext text = TextContextForText("Month: January");
    const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
    // Fatal: `derivations.front()` below is undefined on an empty vector.
    ASSERT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 14)));

    StatusOr<const SemanticValue*> maybe_value =
        composer.Eval(text, derivations.front(), &arena_);
    // Fatal: `ValueOrDie()` aborts the whole test binary on an error status.
    ASSERT_TRUE(maybe_value.ok());
    const SemanticValue* semantic_value = maybe_value.ValueOrDie();
    ASSERT_NE(semantic_value, nullptr);

    const TestValue* value = semantic_value->Table<TestValue>();
    ASSERT_NE(value, nullptr);
    EXPECT_EQ(value->value(), 1);
  }

  {
    const TextContext text = TextContextForText("Month: February");
    const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
    ASSERT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 15)));

    StatusOr<const SemanticValue*> maybe_value =
        composer.Eval(text, derivations.front(), &arena_);
    ASSERT_TRUE(maybe_value.ok());
    const SemanticValue* semantic_value = maybe_value.ValueOrDie();
    ASSERT_NE(semantic_value, nullptr);

    const TestValue* value = semantic_value->Table<TestValue>();
    ASSERT_NE(value, nullptr);
    EXPECT_EQ(value->value(), 2);
  }
}
111
// Checks that a semantic expression referring to a constituent (id 1) is
// resolved by recursively evaluating that constituent's own expression.
TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) {
  RulesSetT model;
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules rules(locale_shard_map);
  const int test_value_type =
      TypeIdForName(semantic_values_schema_.get(),
                    "libtextclassifier3.grammar.TestValue")
          .value();
  constexpr int kDateRule = 0;
  {
    // `<month> ::= january` evaluates to a constant TestValue of 42. The
    // callback parameter is the index the expression will have once appended.
    rules.Add("<month>", {"january"},
              static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
              /*callback_param=*/model.semantic_expression.size());
    TestValueT value;
    value.value = 42;
    const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
    ConstValueExpressionT const_value;
    const_value.type = test_value_type;
    const_value.base_type = reflection::BaseType::Obj;
    const_value.value.assign(serialized_value.begin(), serialized_value.end());
    model.semantic_expression.emplace_back(new SemanticExpressionT);
    model.semantic_expression.back()->expression.Set(const_value);
  }
  {
    // Define constituents of the rule.
    // TODO(smillius): Add support in the rules builder to directly specify
    // constituent ids in the rule, e.g. `<date> ::= <month>@0? <4_digits>`.
    // `<date_@0>` maps `<month>` to mapping id 1, which the constituent
    // expression below refers to.
    rules.Add("<date_@0>", {"<month>"},
              static_cast<CallbackId>(DefaultCallback::kMapping),
              /*callback_param=*/1);
    rules.Add("<date>", {"<date_@0>?", "<4_digits>"},
              static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
              /*callback_param=*/model.semantic_expression.size());
    ConstituentExpressionT constituent;
    constituent.id = 1;
    model.semantic_expression.emplace_back(new SemanticExpressionT);
    model.semantic_expression.back()->expression.Set(constituent);
    rules.Add("<date_rule>", {"<date>"},
              static_cast<CallbackId>(DefaultCallback::kRootRule),
              /*callback_param=*/kDateRule);
  }

  rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
  const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
  Parser parser(unilib_.get(),
                flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
  SemanticComposer composer(semantic_values_schema_.get());

  {
    const TextContext text = TextContextForText("Event: January 2020");
    const std::vector<Derivation> derivations =
        ValidDeduplicatedDerivations(parser.Parse(text, &arena_));
    // Fatal: `derivations.front()` below is undefined on an empty vector.
    ASSERT_THAT(derivations, ElementsAre(IsDerivation(kDateRule, 7, 19)));

    StatusOr<const SemanticValue*> maybe_value =
        composer.Eval(text, derivations.front(), &arena_);
    // Fatal: `ValueOrDie()` aborts the whole test binary on an error status.
    ASSERT_TRUE(maybe_value.ok());
    const SemanticValue* semantic_value = maybe_value.ValueOrDie();
    ASSERT_NE(semantic_value, nullptr);

    const TestValue* value = semantic_value->Table<TestValue>();
    ASSERT_NE(value, nullptr);
    EXPECT_EQ(value->value(), 42);
  }
}
175
176 } // namespace
177 } // namespace libtextclassifier3::grammar
178