/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/composer.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/status_macros.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
21*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/compose-eval.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/const-eval.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/constituent-eval.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/merge-values-eval.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/parse-number-eval.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/span-eval.h"
27*993b0882SAndroid Build Coastguard Worker
28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3::grammar {
29*993b0882SAndroid Build Coastguard Worker namespace {
30*993b0882SAndroid Build Coastguard Worker
31*993b0882SAndroid Build Coastguard Worker // Gathers all constituents of a rule and index them.
32*993b0882SAndroid Build Coastguard Worker // The constituents are numbered in the rule construction. But consituents could
33*993b0882SAndroid Build Coastguard Worker // be in optional parts of the rule and might not be present in a match.
34*993b0882SAndroid Build Coastguard Worker // This finds all constituents that are present in a match and allows to
35*993b0882SAndroid Build Coastguard Worker // retrieve them by their index.
GatherConstituents(const ParseTree * root)36*993b0882SAndroid Build Coastguard Worker std::unordered_map<int, const ParseTree*> GatherConstituents(
37*993b0882SAndroid Build Coastguard Worker const ParseTree* root) {
38*993b0882SAndroid Build Coastguard Worker std::unordered_map<int, const ParseTree*> constituents;
39*993b0882SAndroid Build Coastguard Worker Traverse(root, [root, &constituents](const ParseTree* node) {
40*993b0882SAndroid Build Coastguard Worker switch (node->type) {
41*993b0882SAndroid Build Coastguard Worker case ParseTree::Type::kMapping:
42*993b0882SAndroid Build Coastguard Worker TC3_CHECK(node->IsUnaryRule());
43*993b0882SAndroid Build Coastguard Worker constituents[static_cast<const MappingNode*>(node)->id] =
44*993b0882SAndroid Build Coastguard Worker node->unary_rule_rhs();
45*993b0882SAndroid Build Coastguard Worker return false;
46*993b0882SAndroid Build Coastguard Worker case ParseTree::Type::kDefault:
47*993b0882SAndroid Build Coastguard Worker // Continue traversal.
48*993b0882SAndroid Build Coastguard Worker return true;
49*993b0882SAndroid Build Coastguard Worker default:
50*993b0882SAndroid Build Coastguard Worker // Don't continue the traversal if we are not at the root node.
51*993b0882SAndroid Build Coastguard Worker // This could e.g. be an assertion node.
52*993b0882SAndroid Build Coastguard Worker return (node == root);
53*993b0882SAndroid Build Coastguard Worker }
54*993b0882SAndroid Build Coastguard Worker });
55*993b0882SAndroid Build Coastguard Worker return constituents;
56*993b0882SAndroid Build Coastguard Worker }
57*993b0882SAndroid Build Coastguard Worker
58*993b0882SAndroid Build Coastguard Worker } // namespace
59*993b0882SAndroid Build Coastguard Worker
SemanticComposer(const reflection::Schema * semantic_values_schema)60*993b0882SAndroid Build Coastguard Worker SemanticComposer::SemanticComposer(
61*993b0882SAndroid Build Coastguard Worker const reflection::Schema* semantic_values_schema) {
62*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression,
63*993b0882SAndroid Build Coastguard Worker std::make_unique<ArithmeticExpressionEvaluator>(this));
64*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression,
65*993b0882SAndroid Build Coastguard Worker std::make_unique<ConstituentEvaluator>());
66*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression,
67*993b0882SAndroid Build Coastguard Worker std::make_unique<ParseNumberEvaluator>(this));
68*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression,
69*993b0882SAndroid Build Coastguard Worker std::make_unique<SpanAsStringEvaluator>());
70*993b0882SAndroid Build Coastguard Worker if (semantic_values_schema != nullptr) {
71*993b0882SAndroid Build Coastguard Worker // Register semantic functions.
72*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(
73*993b0882SAndroid Build Coastguard Worker SemanticExpression_::Expression_ComposeExpression,
74*993b0882SAndroid Build Coastguard Worker std::make_unique<ComposeEvaluator>(this, semantic_values_schema));
75*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(
76*993b0882SAndroid Build Coastguard Worker SemanticExpression_::Expression_ConstValueExpression,
77*993b0882SAndroid Build Coastguard Worker std::make_unique<ConstEvaluator>(semantic_values_schema));
78*993b0882SAndroid Build Coastguard Worker evaluators_.emplace(
79*993b0882SAndroid Build Coastguard Worker SemanticExpression_::Expression_MergeValueExpression,
80*993b0882SAndroid Build Coastguard Worker std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema));
81*993b0882SAndroid Build Coastguard Worker }
82*993b0882SAndroid Build Coastguard Worker }
83*993b0882SAndroid Build Coastguard Worker
Eval(const TextContext & text_context,const Derivation & derivation,UnsafeArena * arena) const84*993b0882SAndroid Build Coastguard Worker StatusOr<const SemanticValue*> SemanticComposer::Eval(
85*993b0882SAndroid Build Coastguard Worker const TextContext& text_context, const Derivation& derivation,
86*993b0882SAndroid Build Coastguard Worker UnsafeArena* arena) const {
87*993b0882SAndroid Build Coastguard Worker if (!derivation.parse_tree->IsUnaryRule() ||
88*993b0882SAndroid Build Coastguard Worker derivation.parse_tree->unary_rule_rhs()->type !=
89*993b0882SAndroid Build Coastguard Worker ParseTree::Type::kExpression) {
90*993b0882SAndroid Build Coastguard Worker return nullptr;
91*993b0882SAndroid Build Coastguard Worker }
92*993b0882SAndroid Build Coastguard Worker return Eval(text_context,
93*993b0882SAndroid Build Coastguard Worker static_cast<const SemanticExpressionNode*>(
94*993b0882SAndroid Build Coastguard Worker derivation.parse_tree->unary_rule_rhs()),
95*993b0882SAndroid Build Coastguard Worker arena);
96*993b0882SAndroid Build Coastguard Worker }
97*993b0882SAndroid Build Coastguard Worker
Eval(const TextContext & text_context,const SemanticExpressionNode * derivation,UnsafeArena * arena) const98*993b0882SAndroid Build Coastguard Worker StatusOr<const SemanticValue*> SemanticComposer::Eval(
99*993b0882SAndroid Build Coastguard Worker const TextContext& text_context, const SemanticExpressionNode* derivation,
100*993b0882SAndroid Build Coastguard Worker UnsafeArena* arena) const {
101*993b0882SAndroid Build Coastguard Worker // Evaluate constituents.
102*993b0882SAndroid Build Coastguard Worker EvalContext context{&text_context, derivation};
103*993b0882SAndroid Build Coastguard Worker for (const auto& [constituent_index, constituent] :
104*993b0882SAndroid Build Coastguard Worker GatherConstituents(derivation)) {
105*993b0882SAndroid Build Coastguard Worker if (constituent->type == ParseTree::Type::kExpression) {
106*993b0882SAndroid Build Coastguard Worker TC3_ASSIGN_OR_RETURN(
107*993b0882SAndroid Build Coastguard Worker context.rule_constituents[constituent_index],
108*993b0882SAndroid Build Coastguard Worker Eval(text_context,
109*993b0882SAndroid Build Coastguard Worker static_cast<const SemanticExpressionNode*>(constituent), arena));
110*993b0882SAndroid Build Coastguard Worker } else {
111*993b0882SAndroid Build Coastguard Worker // Just use the text of the constituent if no semantic expression was
112*993b0882SAndroid Build Coastguard Worker // defined.
113*993b0882SAndroid Build Coastguard Worker context.rule_constituents[constituent_index] = SemanticValue::Create(
114*993b0882SAndroid Build Coastguard Worker text_context.Span(constituent->codepoint_span), arena);
115*993b0882SAndroid Build Coastguard Worker }
116*993b0882SAndroid Build Coastguard Worker }
117*993b0882SAndroid Build Coastguard Worker return Apply(context, derivation->expression, arena);
118*993b0882SAndroid Build Coastguard Worker }
119*993b0882SAndroid Build Coastguard Worker
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const120*993b0882SAndroid Build Coastguard Worker StatusOr<const SemanticValue*> SemanticComposer::Apply(
121*993b0882SAndroid Build Coastguard Worker const EvalContext& context, const SemanticExpression* expression,
122*993b0882SAndroid Build Coastguard Worker UnsafeArena* arena) const {
123*993b0882SAndroid Build Coastguard Worker const auto handler_it = evaluators_.find(expression->expression_type());
124*993b0882SAndroid Build Coastguard Worker if (handler_it == evaluators_.end()) {
125*993b0882SAndroid Build Coastguard Worker return Status(StatusCode::INVALID_ARGUMENT,
126*993b0882SAndroid Build Coastguard Worker std::string("Unhandled expression type: ") +
127*993b0882SAndroid Build Coastguard Worker EnumNameExpression(expression->expression_type()));
128*993b0882SAndroid Build Coastguard Worker }
129*993b0882SAndroid Build Coastguard Worker return handler_it->second->Apply(context, expression, arena);
130*993b0882SAndroid Build Coastguard Worker }
131*993b0882SAndroid Build Coastguard Worker
132*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3::grammar
133