1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/semantics/evaluators/compose-eval.h"
18
19 #include "utils/base/status_macros.h"
20 #include "utils/strings/stringpiece.h"
21
22 namespace libtextclassifier3::grammar {
23 namespace {
24
25 // Tries setting a singular field.
26 template <typename T>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)27 Status TrySetField(const reflection::Field* field, const SemanticValue* value,
28 MutableFlatbuffer* result) {
29 if (!result->Set<T>(field, value->Value<T>())) {
30 return Status(StatusCode::INVALID_ARGUMENT, "Could not set field.");
31 }
32 return Status::OK;
33 }
34
35 template <>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)36 Status TrySetField<flatbuffers::Table>(const reflection::Field* field,
37 const SemanticValue* value,
38 MutableFlatbuffer* result) {
39 auto* flatbuffer = result->Mutable(field);
40 if (flatbuffer == nullptr || !flatbuffer->MergeFrom(value->Table())) {
41 return Status(StatusCode::INVALID_ARGUMENT,
42 "Could not set sub-field in result.");
43 }
44 return Status::OK;
45 }
46
47 // Tries adding a value to a repeated field.
48 template <typename T>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)49 Status TryAddField(const reflection::Field* field, const SemanticValue* value,
50 MutableFlatbuffer* result) {
51 auto* flatbuffer = result->Repeated(field);
52 if (flatbuffer == nullptr || !flatbuffer->Add(value->Value<T>())) {
53 return Status(StatusCode::INVALID_ARGUMENT, "Could not add field.");
54 }
55 return Status::OK;
56 }
57
58 template <>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)59 Status TryAddField<flatbuffers::Table>(const reflection::Field* field,
60 const SemanticValue* value,
61 MutableFlatbuffer* result) {
62 auto* flatbuffer = result->Repeated(field);
63 auto* added = flatbuffer == nullptr ? nullptr : flatbuffer->Add();
64 if (added == nullptr || !added->MergeFrom(value->Table())) {
65 return Status(StatusCode::INVALID_ARGUMENT,
66 "Could not add message to repeated field.");
67 }
68 return Status::OK;
69 }
70
71 // Tries adding or setting a value for a field.
72 template <typename T>
TrySetOrAddValue(const FlatbufferFieldPath * field_path,const SemanticValue * value,MutableFlatbuffer * result)73 Status TrySetOrAddValue(const FlatbufferFieldPath* field_path,
74 const SemanticValue* value, MutableFlatbuffer* result) {
75 MutableFlatbuffer* parent;
76 const reflection::Field* field;
77 if (!result->GetFieldWithParent(field_path, &parent, &field)) {
78 return Status(StatusCode::INVALID_ARGUMENT, "Could not get field.");
79 }
80 if (field->type()->base_type() == reflection::Vector) {
81 return TryAddField<T>(field, value, parent);
82 } else {
83 return TrySetField<T>(field, value, parent);
84 }
85 }
86
87 } // namespace
88
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const89 StatusOr<const SemanticValue*> ComposeEvaluator::Apply(
90 const EvalContext& context, const SemanticExpression* expression,
91 UnsafeArena* arena) const {
92 const ComposeExpression* compose_expression =
93 expression->expression_as_ComposeExpression();
94 std::unique_ptr<MutableFlatbuffer> result =
95 semantic_value_builder_.NewTable(compose_expression->type());
96
97 if (result == nullptr) {
98 return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
99 }
100
101 // Evaluate and set fields.
102 if (compose_expression->fields() != nullptr) {
103 for (const ComposeExpression_::Field* field :
104 *compose_expression->fields()) {
105 // Evaluate argument.
106 TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
107 composer_->Apply(context, field->value(), arena));
108 if (value == nullptr) {
109 continue;
110 }
111
112 switch (value->base_type()) {
113 case reflection::BaseType::Bool: {
114 TC3_RETURN_IF_ERROR(
115 TrySetOrAddValue<bool>(field->path(), value, result.get()));
116 break;
117 }
118 case reflection::BaseType::Byte: {
119 TC3_RETURN_IF_ERROR(
120 TrySetOrAddValue<int8>(field->path(), value, result.get()));
121 break;
122 }
123 case reflection::BaseType::UByte: {
124 TC3_RETURN_IF_ERROR(
125 TrySetOrAddValue<uint8>(field->path(), value, result.get()));
126 break;
127 }
128 case reflection::BaseType::Short: {
129 TC3_RETURN_IF_ERROR(
130 TrySetOrAddValue<int16>(field->path(), value, result.get()));
131 break;
132 }
133 case reflection::BaseType::UShort: {
134 TC3_RETURN_IF_ERROR(
135 TrySetOrAddValue<uint16>(field->path(), value, result.get()));
136 break;
137 }
138 case reflection::BaseType::Int: {
139 TC3_RETURN_IF_ERROR(
140 TrySetOrAddValue<int32>(field->path(), value, result.get()));
141 break;
142 }
143 case reflection::BaseType::UInt: {
144 TC3_RETURN_IF_ERROR(
145 TrySetOrAddValue<uint32>(field->path(), value, result.get()));
146 break;
147 }
148 case reflection::BaseType::Long: {
149 TC3_RETURN_IF_ERROR(
150 TrySetOrAddValue<int64>(field->path(), value, result.get()));
151 break;
152 }
153 case reflection::BaseType::ULong: {
154 TC3_RETURN_IF_ERROR(
155 TrySetOrAddValue<uint64>(field->path(), value, result.get()));
156 break;
157 }
158 case reflection::BaseType::Float: {
159 TC3_RETURN_IF_ERROR(
160 TrySetOrAddValue<float>(field->path(), value, result.get()));
161 break;
162 }
163 case reflection::BaseType::Double: {
164 TC3_RETURN_IF_ERROR(
165 TrySetOrAddValue<double>(field->path(), value, result.get()));
166 break;
167 }
168 case reflection::BaseType::String: {
169 TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>(
170 field->path(), value, result.get()));
171 break;
172 }
173 case reflection::BaseType::Obj: {
174 TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>(
175 field->path(), value, result.get()));
176 break;
177 }
178 default:
179 return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type.");
180 }
181 }
182 }
183
184 return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
185 }
186
187 } // namespace libtextclassifier3::grammar
188