xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/semantics/composer.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/composer.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/status_macros.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
21*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/compose-eval.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/const-eval.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/constituent-eval.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/merge-values-eval.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/parse-number-eval.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/semantics/evaluators/span-eval.h"
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3::grammar {
29*993b0882SAndroid Build Coastguard Worker namespace {
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker // Gathers all constituents of a rule and index them.
32*993b0882SAndroid Build Coastguard Worker // The constituents are numbered in the rule construction. But consituents could
33*993b0882SAndroid Build Coastguard Worker // be in optional parts of the rule and might not be present in a match.
34*993b0882SAndroid Build Coastguard Worker // This finds all constituents that are present in a match and allows to
35*993b0882SAndroid Build Coastguard Worker // retrieve them by their index.
GatherConstituents(const ParseTree * root)36*993b0882SAndroid Build Coastguard Worker std::unordered_map<int, const ParseTree*> GatherConstituents(
37*993b0882SAndroid Build Coastguard Worker     const ParseTree* root) {
38*993b0882SAndroid Build Coastguard Worker   std::unordered_map<int, const ParseTree*> constituents;
39*993b0882SAndroid Build Coastguard Worker   Traverse(root, [root, &constituents](const ParseTree* node) {
40*993b0882SAndroid Build Coastguard Worker     switch (node->type) {
41*993b0882SAndroid Build Coastguard Worker       case ParseTree::Type::kMapping:
42*993b0882SAndroid Build Coastguard Worker         TC3_CHECK(node->IsUnaryRule());
43*993b0882SAndroid Build Coastguard Worker         constituents[static_cast<const MappingNode*>(node)->id] =
44*993b0882SAndroid Build Coastguard Worker             node->unary_rule_rhs();
45*993b0882SAndroid Build Coastguard Worker         return false;
46*993b0882SAndroid Build Coastguard Worker       case ParseTree::Type::kDefault:
47*993b0882SAndroid Build Coastguard Worker         // Continue traversal.
48*993b0882SAndroid Build Coastguard Worker         return true;
49*993b0882SAndroid Build Coastguard Worker       default:
50*993b0882SAndroid Build Coastguard Worker         // Don't continue the traversal if we are not at the root node.
51*993b0882SAndroid Build Coastguard Worker         // This could e.g. be an assertion node.
52*993b0882SAndroid Build Coastguard Worker         return (node == root);
53*993b0882SAndroid Build Coastguard Worker     }
54*993b0882SAndroid Build Coastguard Worker   });
55*993b0882SAndroid Build Coastguard Worker   return constituents;
56*993b0882SAndroid Build Coastguard Worker }
57*993b0882SAndroid Build Coastguard Worker 
58*993b0882SAndroid Build Coastguard Worker }  // namespace
59*993b0882SAndroid Build Coastguard Worker 
SemanticComposer(const reflection::Schema * semantic_values_schema)60*993b0882SAndroid Build Coastguard Worker SemanticComposer::SemanticComposer(
61*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* semantic_values_schema) {
62*993b0882SAndroid Build Coastguard Worker   evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression,
63*993b0882SAndroid Build Coastguard Worker                       std::make_unique<ArithmeticExpressionEvaluator>(this));
64*993b0882SAndroid Build Coastguard Worker   evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression,
65*993b0882SAndroid Build Coastguard Worker                       std::make_unique<ConstituentEvaluator>());
66*993b0882SAndroid Build Coastguard Worker   evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression,
67*993b0882SAndroid Build Coastguard Worker                       std::make_unique<ParseNumberEvaluator>(this));
68*993b0882SAndroid Build Coastguard Worker   evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression,
69*993b0882SAndroid Build Coastguard Worker                       std::make_unique<SpanAsStringEvaluator>());
70*993b0882SAndroid Build Coastguard Worker   if (semantic_values_schema != nullptr) {
71*993b0882SAndroid Build Coastguard Worker     // Register semantic functions.
72*993b0882SAndroid Build Coastguard Worker     evaluators_.emplace(
73*993b0882SAndroid Build Coastguard Worker         SemanticExpression_::Expression_ComposeExpression,
74*993b0882SAndroid Build Coastguard Worker         std::make_unique<ComposeEvaluator>(this, semantic_values_schema));
75*993b0882SAndroid Build Coastguard Worker     evaluators_.emplace(
76*993b0882SAndroid Build Coastguard Worker         SemanticExpression_::Expression_ConstValueExpression,
77*993b0882SAndroid Build Coastguard Worker         std::make_unique<ConstEvaluator>(semantic_values_schema));
78*993b0882SAndroid Build Coastguard Worker     evaluators_.emplace(
79*993b0882SAndroid Build Coastguard Worker         SemanticExpression_::Expression_MergeValueExpression,
80*993b0882SAndroid Build Coastguard Worker         std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema));
81*993b0882SAndroid Build Coastguard Worker   }
82*993b0882SAndroid Build Coastguard Worker }
83*993b0882SAndroid Build Coastguard Worker 
Eval(const TextContext & text_context,const Derivation & derivation,UnsafeArena * arena) const84*993b0882SAndroid Build Coastguard Worker StatusOr<const SemanticValue*> SemanticComposer::Eval(
85*993b0882SAndroid Build Coastguard Worker     const TextContext& text_context, const Derivation& derivation,
86*993b0882SAndroid Build Coastguard Worker     UnsafeArena* arena) const {
87*993b0882SAndroid Build Coastguard Worker   if (!derivation.parse_tree->IsUnaryRule() ||
88*993b0882SAndroid Build Coastguard Worker       derivation.parse_tree->unary_rule_rhs()->type !=
89*993b0882SAndroid Build Coastguard Worker           ParseTree::Type::kExpression) {
90*993b0882SAndroid Build Coastguard Worker     return nullptr;
91*993b0882SAndroid Build Coastguard Worker   }
92*993b0882SAndroid Build Coastguard Worker   return Eval(text_context,
93*993b0882SAndroid Build Coastguard Worker               static_cast<const SemanticExpressionNode*>(
94*993b0882SAndroid Build Coastguard Worker                   derivation.parse_tree->unary_rule_rhs()),
95*993b0882SAndroid Build Coastguard Worker               arena);
96*993b0882SAndroid Build Coastguard Worker }
97*993b0882SAndroid Build Coastguard Worker 
Eval(const TextContext & text_context,const SemanticExpressionNode * derivation,UnsafeArena * arena) const98*993b0882SAndroid Build Coastguard Worker StatusOr<const SemanticValue*> SemanticComposer::Eval(
99*993b0882SAndroid Build Coastguard Worker     const TextContext& text_context, const SemanticExpressionNode* derivation,
100*993b0882SAndroid Build Coastguard Worker     UnsafeArena* arena) const {
101*993b0882SAndroid Build Coastguard Worker   // Evaluate constituents.
102*993b0882SAndroid Build Coastguard Worker   EvalContext context{&text_context, derivation};
103*993b0882SAndroid Build Coastguard Worker   for (const auto& [constituent_index, constituent] :
104*993b0882SAndroid Build Coastguard Worker        GatherConstituents(derivation)) {
105*993b0882SAndroid Build Coastguard Worker     if (constituent->type == ParseTree::Type::kExpression) {
106*993b0882SAndroid Build Coastguard Worker       TC3_ASSIGN_OR_RETURN(
107*993b0882SAndroid Build Coastguard Worker           context.rule_constituents[constituent_index],
108*993b0882SAndroid Build Coastguard Worker           Eval(text_context,
109*993b0882SAndroid Build Coastguard Worker                static_cast<const SemanticExpressionNode*>(constituent), arena));
110*993b0882SAndroid Build Coastguard Worker     } else {
111*993b0882SAndroid Build Coastguard Worker       // Just use the text of the constituent if no semantic expression was
112*993b0882SAndroid Build Coastguard Worker       // defined.
113*993b0882SAndroid Build Coastguard Worker       context.rule_constituents[constituent_index] = SemanticValue::Create(
114*993b0882SAndroid Build Coastguard Worker           text_context.Span(constituent->codepoint_span), arena);
115*993b0882SAndroid Build Coastguard Worker     }
116*993b0882SAndroid Build Coastguard Worker   }
117*993b0882SAndroid Build Coastguard Worker   return Apply(context, derivation->expression, arena);
118*993b0882SAndroid Build Coastguard Worker }
119*993b0882SAndroid Build Coastguard Worker 
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const120*993b0882SAndroid Build Coastguard Worker StatusOr<const SemanticValue*> SemanticComposer::Apply(
121*993b0882SAndroid Build Coastguard Worker     const EvalContext& context, const SemanticExpression* expression,
122*993b0882SAndroid Build Coastguard Worker     UnsafeArena* arena) const {
123*993b0882SAndroid Build Coastguard Worker   const auto handler_it = evaluators_.find(expression->expression_type());
124*993b0882SAndroid Build Coastguard Worker   if (handler_it == evaluators_.end()) {
125*993b0882SAndroid Build Coastguard Worker     return Status(StatusCode::INVALID_ARGUMENT,
126*993b0882SAndroid Build Coastguard Worker                   std::string("Unhandled expression type: ") +
127*993b0882SAndroid Build Coastguard Worker                       EnumNameExpression(expression->expression_type()));
128*993b0882SAndroid Build Coastguard Worker   }
129*993b0882SAndroid Build Coastguard Worker   return handler_it->second->Apply(context, expression, arena);
130*993b0882SAndroid Build Coastguard Worker }
131*993b0882SAndroid Build Coastguard Worker 
132*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3::grammar
133