// Source: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/semantics/composer_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "utils/grammar/semantics/composer.h"
18 
19 #include "utils/flatbuffers/flatbuffers.h"
20 #include "utils/flatbuffers/reflection.h"
21 #include "utils/grammar/parsing/derivation.h"
22 #include "utils/grammar/parsing/parser.h"
23 #include "utils/grammar/rules_generated.h"
24 #include "utils/grammar/semantics/expression_generated.h"
25 #include "utils/grammar/testing/utils.h"
26 #include "utils/grammar/testing/value_generated.h"
27 #include "utils/grammar/types.h"
28 #include "utils/grammar/utils/rules.h"
29 #include "gmock/gmock.h"
30 #include "gtest/gtest.h"
31 
32 namespace libtextclassifier3::grammar {
33 namespace {
34 
35 using ::testing::ElementsAre;
36 
37 class SemanticComposerTest : public GrammarTest {};
38 
TEST_F(SemanticComposerTest,EvaluatesSimpleMapping)39 TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
40   RulesSetT model;
41   grammar::LocaleShardMap locale_shard_map =
42       grammar::LocaleShardMap::CreateLocaleShardMap({""});
43   Rules rules(locale_shard_map);
44   const int test_value_type =
45       TypeIdForName(semantic_values_schema_.get(),
46                     "libtextclassifier3.grammar.TestValue")
47           .value();
48   {
49     rules.Add("<month>", {"january"},
50               static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
51               /*callback_param=*/model.semantic_expression.size());
52     TestValueT value;
53     value.value = 1;
54     const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
55     ConstValueExpressionT const_value;
56     const_value.base_type = reflection::BaseType::Obj;
57     const_value.type = test_value_type;
58     const_value.value.assign(serialized_value.begin(), serialized_value.end());
59     model.semantic_expression.emplace_back(new SemanticExpressionT);
60     model.semantic_expression.back()->expression.Set(const_value);
61   }
62   {
63     rules.Add("<month>", {"february"},
64               static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
65               /*callback_param=*/model.semantic_expression.size());
66     TestValueT value;
67     value.value = 2;
68     const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
69     ConstValueExpressionT const_value;
70     const_value.base_type = reflection::BaseType::Obj;
71     const_value.type = test_value_type;
72     const_value.value.assign(serialized_value.begin(), serialized_value.end());
73     model.semantic_expression.emplace_back(new SemanticExpressionT);
74     model.semantic_expression.back()->expression.Set(const_value);
75   }
76   const int kMonth = 0;
77   rules.Add("<month_rule>", {"<month>"},
78             static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
79   rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
80   const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
81   Parser parser(unilib_.get(),
82                 flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
83   SemanticComposer composer(semantic_values_schema_.get());
84 
85   {
86     const TextContext text = TextContextForText("Month: January");
87     const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
88     EXPECT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 14)));
89 
90     StatusOr<const SemanticValue*> maybe_value =
91         composer.Eval(text, derivations.front(), &arena_);
92     EXPECT_TRUE(maybe_value.ok());
93 
94     const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
95     EXPECT_EQ(value->value(), 1);
96   }
97 
98   {
99     const TextContext text = TextContextForText("Month: February");
100     const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
101     EXPECT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 15)));
102 
103     StatusOr<const SemanticValue*> maybe_value =
104         composer.Eval(text, derivations.front(), &arena_);
105     EXPECT_TRUE(maybe_value.ok());
106 
107     const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
108     EXPECT_EQ(value->value(), 2);
109   }
110 }
111 
TEST_F(SemanticComposerTest,RecursivelyEvaluatesConstituents)112 TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) {
113   RulesSetT model;
114   grammar::LocaleShardMap locale_shard_map =
115       grammar::LocaleShardMap::CreateLocaleShardMap({""});
116   Rules rules(locale_shard_map);
117   const int test_value_type =
118       TypeIdForName(semantic_values_schema_.get(),
119                     "libtextclassifier3.grammar.TestValue")
120           .value();
121   constexpr int kDateRule = 0;
122   {
123     rules.Add("<month>", {"january"},
124               static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
125               /*callback_param=*/model.semantic_expression.size());
126     TestValueT value;
127     value.value = 42;
128     const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
129     ConstValueExpressionT const_value;
130     const_value.type = test_value_type;
131     const_value.base_type = reflection::BaseType::Obj;
132     const_value.value.assign(serialized_value.begin(), serialized_value.end());
133     model.semantic_expression.emplace_back(new SemanticExpressionT);
134     model.semantic_expression.back()->expression.Set(const_value);
135   }
136   {
137     // Define constituents of the rule.
138     // TODO(smillius): Add support in the rules builder to directly specify
139     // constituent ids in the rule, e.g. `<date> ::= <month>@0? <4_digits>`.
140     rules.Add("<date_@0>", {"<month>"},
141               static_cast<CallbackId>(DefaultCallback::kMapping),
142               /*callback_param=*/1);
143     rules.Add("<date>", {"<date_@0>?", "<4_digits>"},
144               static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
145               /*callback_param=*/model.semantic_expression.size());
146     ConstituentExpressionT constituent;
147     constituent.id = 1;
148     model.semantic_expression.emplace_back(new SemanticExpressionT);
149     model.semantic_expression.back()->expression.Set(constituent);
150     rules.Add("<date_rule>", {"<date>"},
151               static_cast<CallbackId>(DefaultCallback::kRootRule),
152               /*callback_param=*/kDateRule);
153   }
154 
155   rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
156   const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
157   Parser parser(unilib_.get(),
158                 flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
159   SemanticComposer composer(semantic_values_schema_.get());
160 
161   {
162     const TextContext text = TextContextForText("Event: January 2020");
163     const std::vector<Derivation> derivations =
164         ValidDeduplicatedDerivations(parser.Parse(text, &arena_));
165     EXPECT_THAT(derivations, ElementsAre(IsDerivation(kDateRule, 7, 19)));
166 
167     StatusOr<const SemanticValue*> maybe_value =
168         composer.Eval(text, derivations.front(), &arena_);
169     EXPECT_TRUE(maybe_value.ok());
170 
171     const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
172     EXPECT_EQ(value->value(), 42);
173   }
174 }
175 
176 }  // namespace
177 }  // namespace libtextclassifier3::grammar
178