xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/semantics/evaluators/compose-eval.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/semantics/evaluators/compose-eval.h"
18 
19 #include "utils/base/status_macros.h"
20 #include "utils/strings/stringpiece.h"
21 
22 namespace libtextclassifier3::grammar {
23 namespace {
24 
25 // Tries setting a singular field.
26 template <typename T>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)27 Status TrySetField(const reflection::Field* field, const SemanticValue* value,
28                    MutableFlatbuffer* result) {
29   if (!result->Set<T>(field, value->Value<T>())) {
30     return Status(StatusCode::INVALID_ARGUMENT, "Could not set field.");
31   }
32   return Status::OK;
33 }
34 
35 template <>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)36 Status TrySetField<flatbuffers::Table>(const reflection::Field* field,
37                                        const SemanticValue* value,
38                                        MutableFlatbuffer* result) {
39   auto* flatbuffer = result->Mutable(field);
40   if (flatbuffer == nullptr || !flatbuffer->MergeFrom(value->Table())) {
41     return Status(StatusCode::INVALID_ARGUMENT,
42                   "Could not set sub-field in result.");
43   }
44   return Status::OK;
45 }
46 
47 // Tries adding a value to a repeated field.
48 template <typename T>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)49 Status TryAddField(const reflection::Field* field, const SemanticValue* value,
50                    MutableFlatbuffer* result) {
51   auto* flatbuffer = result->Repeated(field);
52   if (flatbuffer == nullptr || !flatbuffer->Add(value->Value<T>())) {
53     return Status(StatusCode::INVALID_ARGUMENT, "Could not add field.");
54   }
55   return Status::OK;
56 }
57 
58 template <>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)59 Status TryAddField<flatbuffers::Table>(const reflection::Field* field,
60                                        const SemanticValue* value,
61                                        MutableFlatbuffer* result) {
62   auto* flatbuffer = result->Repeated(field);
63   auto* added = flatbuffer == nullptr ? nullptr : flatbuffer->Add();
64   if (added == nullptr || !added->MergeFrom(value->Table())) {
65     return Status(StatusCode::INVALID_ARGUMENT,
66                   "Could not add message to repeated field.");
67   }
68   return Status::OK;
69 }
70 
71 // Tries adding or setting a value for a field.
72 template <typename T>
TrySetOrAddValue(const FlatbufferFieldPath * field_path,const SemanticValue * value,MutableFlatbuffer * result)73 Status TrySetOrAddValue(const FlatbufferFieldPath* field_path,
74                         const SemanticValue* value, MutableFlatbuffer* result) {
75   MutableFlatbuffer* parent;
76   const reflection::Field* field;
77   if (!result->GetFieldWithParent(field_path, &parent, &field)) {
78     return Status(StatusCode::INVALID_ARGUMENT, "Could not get field.");
79   }
80   if (field->type()->base_type() == reflection::Vector) {
81     return TryAddField<T>(field, value, parent);
82   } else {
83     return TrySetField<T>(field, value, parent);
84   }
85 }
86 
87 }  // namespace
88 
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const89 StatusOr<const SemanticValue*> ComposeEvaluator::Apply(
90     const EvalContext& context, const SemanticExpression* expression,
91     UnsafeArena* arena) const {
92   const ComposeExpression* compose_expression =
93       expression->expression_as_ComposeExpression();
94   std::unique_ptr<MutableFlatbuffer> result =
95       semantic_value_builder_.NewTable(compose_expression->type());
96 
97   if (result == nullptr) {
98     return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
99   }
100 
101   // Evaluate and set fields.
102   if (compose_expression->fields() != nullptr) {
103     for (const ComposeExpression_::Field* field :
104          *compose_expression->fields()) {
105       // Evaluate argument.
106       TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
107                            composer_->Apply(context, field->value(), arena));
108       if (value == nullptr) {
109         continue;
110       }
111 
112       switch (value->base_type()) {
113         case reflection::BaseType::Bool: {
114           TC3_RETURN_IF_ERROR(
115               TrySetOrAddValue<bool>(field->path(), value, result.get()));
116           break;
117         }
118         case reflection::BaseType::Byte: {
119           TC3_RETURN_IF_ERROR(
120               TrySetOrAddValue<int8>(field->path(), value, result.get()));
121           break;
122         }
123         case reflection::BaseType::UByte: {
124           TC3_RETURN_IF_ERROR(
125               TrySetOrAddValue<uint8>(field->path(), value, result.get()));
126           break;
127         }
128         case reflection::BaseType::Short: {
129           TC3_RETURN_IF_ERROR(
130               TrySetOrAddValue<int16>(field->path(), value, result.get()));
131           break;
132         }
133         case reflection::BaseType::UShort: {
134           TC3_RETURN_IF_ERROR(
135               TrySetOrAddValue<uint16>(field->path(), value, result.get()));
136           break;
137         }
138         case reflection::BaseType::Int: {
139           TC3_RETURN_IF_ERROR(
140               TrySetOrAddValue<int32>(field->path(), value, result.get()));
141           break;
142         }
143         case reflection::BaseType::UInt: {
144           TC3_RETURN_IF_ERROR(
145               TrySetOrAddValue<uint32>(field->path(), value, result.get()));
146           break;
147         }
148         case reflection::BaseType::Long: {
149           TC3_RETURN_IF_ERROR(
150               TrySetOrAddValue<int64>(field->path(), value, result.get()));
151           break;
152         }
153         case reflection::BaseType::ULong: {
154           TC3_RETURN_IF_ERROR(
155               TrySetOrAddValue<uint64>(field->path(), value, result.get()));
156           break;
157         }
158         case reflection::BaseType::Float: {
159           TC3_RETURN_IF_ERROR(
160               TrySetOrAddValue<float>(field->path(), value, result.get()));
161           break;
162         }
163         case reflection::BaseType::Double: {
164           TC3_RETURN_IF_ERROR(
165               TrySetOrAddValue<double>(field->path(), value, result.get()));
166           break;
167         }
168         case reflection::BaseType::String: {
169           TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>(
170               field->path(), value, result.get()));
171           break;
172         }
173         case reflection::BaseType::Obj: {
174           TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>(
175               field->path(), value, result.get()));
176           break;
177         }
178         default:
179           return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type.");
180       }
181     }
182   }
183 
184   return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
185 }
186 
187 }  // namespace libtextclassifier3::grammar
188