/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/testing/annotator.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/mutable.h"
20*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/reflection.h"
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
23*993b0882SAndroid Build Coastguard Worker
FirstResult(const std::vector<ClassificationResult> & results)24*993b0882SAndroid Build Coastguard Worker std::string FirstResult(const std::vector<ClassificationResult>& results) {
25*993b0882SAndroid Build Coastguard Worker if (results.empty()) {
26*993b0882SAndroid Build Coastguard Worker return "<INVALID RESULTS>";
27*993b0882SAndroid Build Coastguard Worker }
28*993b0882SAndroid Build Coastguard Worker return results[0].collection;
29*993b0882SAndroid Build Coastguard Worker }
30*993b0882SAndroid Build Coastguard Worker
// Reads the whole content of the file at `file_name` into a string.
// The stream is opened with the default open mode. No error is reported when
// the file cannot be opened; the result is an empty string in that case.
std::string ReadFile(const std::string& file_name) {
  std::ifstream input(file_name);
  const std::istreambuf_iterator<char> content_begin(input);
  // A default-constructed stream buffer iterator marks the end of the stream.
  const std::istreambuf_iterator<char> content_end;
  return std::string(content_begin, content_end);
}
35*993b0882SAndroid Build Coastguard Worker
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score,const float priority_score)36*993b0882SAndroid Build Coastguard Worker std::unique_ptr<RegexModel_::PatternT> MakePattern(
37*993b0882SAndroid Build Coastguard Worker const std::string& collection_name, const std::string& pattern,
38*993b0882SAndroid Build Coastguard Worker const bool enabled_for_classification, const bool enabled_for_selection,
39*993b0882SAndroid Build Coastguard Worker const bool enabled_for_annotation, const float score,
40*993b0882SAndroid Build Coastguard Worker const float priority_score) {
41*993b0882SAndroid Build Coastguard Worker std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
42*993b0882SAndroid Build Coastguard Worker result->collection_name = collection_name;
43*993b0882SAndroid Build Coastguard Worker result->pattern = pattern;
44*993b0882SAndroid Build Coastguard Worker // We cannot directly operate with |= on the flag, so use an int here.
45*993b0882SAndroid Build Coastguard Worker int enabled_modes = ModeFlag_NONE;
46*993b0882SAndroid Build Coastguard Worker if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
47*993b0882SAndroid Build Coastguard Worker if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
48*993b0882SAndroid Build Coastguard Worker if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
49*993b0882SAndroid Build Coastguard Worker result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
50*993b0882SAndroid Build Coastguard Worker result->target_classification_score = score;
51*993b0882SAndroid Build Coastguard Worker result->priority_score = priority_score;
52*993b0882SAndroid Build Coastguard Worker return result;
53*993b0882SAndroid Build Coastguard Worker }
54*993b0882SAndroid Build Coastguard Worker
55*993b0882SAndroid Build Coastguard Worker // Shortcut function that doesn't need to specify the priority score.
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score)56*993b0882SAndroid Build Coastguard Worker std::unique_ptr<RegexModel_::PatternT> MakePattern(
57*993b0882SAndroid Build Coastguard Worker const std::string& collection_name, const std::string& pattern,
58*993b0882SAndroid Build Coastguard Worker const bool enabled_for_classification, const bool enabled_for_selection,
59*993b0882SAndroid Build Coastguard Worker const bool enabled_for_annotation, const float score) {
60*993b0882SAndroid Build Coastguard Worker return MakePattern(collection_name, pattern, enabled_for_classification,
61*993b0882SAndroid Build Coastguard Worker enabled_for_selection, enabled_for_annotation,
62*993b0882SAndroid Build Coastguard Worker /*score=*/score,
63*993b0882SAndroid Build Coastguard Worker /*priority_score=*/score);
64*993b0882SAndroid Build Coastguard Worker }
65*993b0882SAndroid Build Coastguard Worker
AddTestRegexModel(ModelT * unpacked_model)66*993b0882SAndroid Build Coastguard Worker void AddTestRegexModel(ModelT* unpacked_model) {
67*993b0882SAndroid Build Coastguard Worker // Add test regex models.
68*993b0882SAndroid Build Coastguard Worker unpacked_model->regex_model->patterns.push_back(MakePattern(
69*993b0882SAndroid Build Coastguard Worker "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
70*993b0882SAndroid Build Coastguard Worker /*enabled_for_classification=*/true,
71*993b0882SAndroid Build Coastguard Worker /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
72*993b0882SAndroid Build Coastguard Worker
73*993b0882SAndroid Build Coastguard Worker // Use meta data to generate custom serialized entity data.
74*993b0882SAndroid Build Coastguard Worker MutableFlatbufferBuilder entity_data_builder(
75*993b0882SAndroid Build Coastguard Worker flatbuffers::GetRoot<reflection::Schema>(
76*993b0882SAndroid Build Coastguard Worker unpacked_model->entity_data_schema.data()));
77*993b0882SAndroid Build Coastguard Worker RegexModel_::PatternT* pattern =
78*993b0882SAndroid Build Coastguard Worker unpacked_model->regex_model->patterns.back().get();
79*993b0882SAndroid Build Coastguard Worker
80*993b0882SAndroid Build Coastguard Worker {
81*993b0882SAndroid Build Coastguard Worker std::unique_ptr<MutableFlatbuffer> entity_data =
82*993b0882SAndroid Build Coastguard Worker entity_data_builder.NewRoot();
83*993b0882SAndroid Build Coastguard Worker entity_data->Set("is_alive", true);
84*993b0882SAndroid Build Coastguard Worker pattern->serialized_entity_data = entity_data->Serialize();
85*993b0882SAndroid Build Coastguard Worker }
86*993b0882SAndroid Build Coastguard Worker pattern->capturing_group.emplace_back(new CapturingGroupT);
87*993b0882SAndroid Build Coastguard Worker pattern->capturing_group.emplace_back(new CapturingGroupT);
88*993b0882SAndroid Build Coastguard Worker pattern->capturing_group.emplace_back(new CapturingGroupT);
89*993b0882SAndroid Build Coastguard Worker pattern->capturing_group.emplace_back(new CapturingGroupT);
90*993b0882SAndroid Build Coastguard Worker // Group 0 is the full match, capturing groups starting at 1.
91*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[1]->entity_field_path.reset(
92*993b0882SAndroid Build Coastguard Worker new FlatbufferFieldPathT);
93*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[1]->entity_field_path->field.emplace_back(
94*993b0882SAndroid Build Coastguard Worker new FlatbufferFieldT);
95*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
96*993b0882SAndroid Build Coastguard Worker "first_name";
97*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[2]->entity_field_path.reset(
98*993b0882SAndroid Build Coastguard Worker new FlatbufferFieldPathT);
99*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[2]->entity_field_path->field.emplace_back(
100*993b0882SAndroid Build Coastguard Worker new FlatbufferFieldT);
101*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
102*993b0882SAndroid Build Coastguard Worker "last_name";
103*993b0882SAndroid Build Coastguard Worker // Set `former_us_president` field if we match Obama.
104*993b0882SAndroid Build Coastguard Worker {
105*993b0882SAndroid Build Coastguard Worker std::unique_ptr<MutableFlatbuffer> entity_data =
106*993b0882SAndroid Build Coastguard Worker entity_data_builder.NewRoot();
107*993b0882SAndroid Build Coastguard Worker entity_data->Set("former_us_president", true);
108*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[2]->serialized_entity_data =
109*993b0882SAndroid Build Coastguard Worker entity_data->Serialize();
110*993b0882SAndroid Build Coastguard Worker }
111*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[3]->entity_field_path.reset(
112*993b0882SAndroid Build Coastguard Worker new FlatbufferFieldPathT);
113*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[3]->entity_field_path->field.emplace_back(
114*993b0882SAndroid Build Coastguard Worker new FlatbufferFieldT);
115*993b0882SAndroid Build Coastguard Worker pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
116*993b0882SAndroid Build Coastguard Worker "age";
117*993b0882SAndroid Build Coastguard Worker }
118*993b0882SAndroid Build Coastguard Worker
CreateEmptyModel(const std::function<void (ModelT * model)> model_update_fn)119*993b0882SAndroid Build Coastguard Worker std::string CreateEmptyModel(
120*993b0882SAndroid Build Coastguard Worker const std::function<void(ModelT* model)> model_update_fn) {
121*993b0882SAndroid Build Coastguard Worker ModelT model;
122*993b0882SAndroid Build Coastguard Worker model_update_fn(&model);
123*993b0882SAndroid Build Coastguard Worker
124*993b0882SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder builder;
125*993b0882SAndroid Build Coastguard Worker FinishModelBuffer(builder, Model::Pack(builder, &model));
126*993b0882SAndroid Build Coastguard Worker return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
127*993b0882SAndroid Build Coastguard Worker builder.GetSize());
128*993b0882SAndroid Build Coastguard Worker }
129*993b0882SAndroid Build Coastguard Worker
130*993b0882SAndroid Build Coastguard Worker // Create fake entity data schema meta data.
AddTestEntitySchemaData(ModelT * unpacked_model)131*993b0882SAndroid Build Coastguard Worker void AddTestEntitySchemaData(ModelT* unpacked_model) {
132*993b0882SAndroid Build Coastguard Worker // Cannot use object oriented API here as that is not available for the
133*993b0882SAndroid Build Coastguard Worker // reflection schema.
134*993b0882SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder schema_builder;
135*993b0882SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<reflection::Field>> fields = {
136*993b0882SAndroid Build Coastguard Worker reflection::CreateField(
137*993b0882SAndroid Build Coastguard Worker schema_builder,
138*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("first_name"),
139*993b0882SAndroid Build Coastguard Worker /*type=*/
140*993b0882SAndroid Build Coastguard Worker reflection::CreateType(schema_builder,
141*993b0882SAndroid Build Coastguard Worker /*base_type=*/reflection::String),
142*993b0882SAndroid Build Coastguard Worker /*id=*/0,
143*993b0882SAndroid Build Coastguard Worker /*offset=*/4),
144*993b0882SAndroid Build Coastguard Worker reflection::CreateField(
145*993b0882SAndroid Build Coastguard Worker schema_builder,
146*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("is_alive"),
147*993b0882SAndroid Build Coastguard Worker /*type=*/
148*993b0882SAndroid Build Coastguard Worker reflection::CreateType(schema_builder,
149*993b0882SAndroid Build Coastguard Worker /*base_type=*/reflection::Bool),
150*993b0882SAndroid Build Coastguard Worker /*id=*/1,
151*993b0882SAndroid Build Coastguard Worker /*offset=*/6),
152*993b0882SAndroid Build Coastguard Worker reflection::CreateField(
153*993b0882SAndroid Build Coastguard Worker schema_builder,
154*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("last_name"),
155*993b0882SAndroid Build Coastguard Worker /*type=*/
156*993b0882SAndroid Build Coastguard Worker reflection::CreateType(schema_builder,
157*993b0882SAndroid Build Coastguard Worker /*base_type=*/reflection::String),
158*993b0882SAndroid Build Coastguard Worker /*id=*/2,
159*993b0882SAndroid Build Coastguard Worker /*offset=*/8),
160*993b0882SAndroid Build Coastguard Worker reflection::CreateField(
161*993b0882SAndroid Build Coastguard Worker schema_builder,
162*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("age"),
163*993b0882SAndroid Build Coastguard Worker /*type=*/
164*993b0882SAndroid Build Coastguard Worker reflection::CreateType(schema_builder,
165*993b0882SAndroid Build Coastguard Worker /*base_type=*/reflection::Int),
166*993b0882SAndroid Build Coastguard Worker /*id=*/3,
167*993b0882SAndroid Build Coastguard Worker /*offset=*/10),
168*993b0882SAndroid Build Coastguard Worker reflection::CreateField(
169*993b0882SAndroid Build Coastguard Worker schema_builder,
170*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("former_us_president"),
171*993b0882SAndroid Build Coastguard Worker /*type=*/
172*993b0882SAndroid Build Coastguard Worker reflection::CreateType(schema_builder,
173*993b0882SAndroid Build Coastguard Worker /*base_type=*/reflection::Bool),
174*993b0882SAndroid Build Coastguard Worker /*id=*/4,
175*993b0882SAndroid Build Coastguard Worker /*offset=*/12)};
176*993b0882SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<reflection::Enum>> enums;
177*993b0882SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<reflection::Object>> objects = {
178*993b0882SAndroid Build Coastguard Worker reflection::CreateObject(
179*993b0882SAndroid Build Coastguard Worker schema_builder,
180*993b0882SAndroid Build Coastguard Worker /*name=*/schema_builder.CreateString("EntityData"),
181*993b0882SAndroid Build Coastguard Worker /*fields=*/
182*993b0882SAndroid Build Coastguard Worker schema_builder.CreateVectorOfSortedTables(&fields))};
183*993b0882SAndroid Build Coastguard Worker schema_builder.Finish(reflection::CreateSchema(
184*993b0882SAndroid Build Coastguard Worker schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
185*993b0882SAndroid Build Coastguard Worker schema_builder.CreateVectorOfSortedTables(&enums),
186*993b0882SAndroid Build Coastguard Worker /*(unused) file_ident=*/0,
187*993b0882SAndroid Build Coastguard Worker /*(unused) file_ext=*/0,
188*993b0882SAndroid Build Coastguard Worker /*root_table*/ objects[0]));
189*993b0882SAndroid Build Coastguard Worker
190*993b0882SAndroid Build Coastguard Worker unpacked_model->entity_data_schema.assign(
191*993b0882SAndroid Build Coastguard Worker schema_builder.GetBufferPointer(),
192*993b0882SAndroid Build Coastguard Worker schema_builder.GetBufferPointer() + schema_builder.GetSize());
193*993b0882SAndroid Build Coastguard Worker }
194*993b0882SAndroid Build Coastguard Worker
MakeAnnotatedSpan(CodepointSpan span,const std::string & collection,const float score,AnnotatedSpan::Source source)195*993b0882SAndroid Build Coastguard Worker AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
196*993b0882SAndroid Build Coastguard Worker const std::string& collection,
197*993b0882SAndroid Build Coastguard Worker const float score,
198*993b0882SAndroid Build Coastguard Worker AnnotatedSpan::Source source) {
199*993b0882SAndroid Build Coastguard Worker AnnotatedSpan result;
200*993b0882SAndroid Build Coastguard Worker result.span = span;
201*993b0882SAndroid Build Coastguard Worker result.classification.push_back({collection, score});
202*993b0882SAndroid Build Coastguard Worker result.source = source;
203*993b0882SAndroid Build Coastguard Worker return result;
204*993b0882SAndroid Build Coastguard Worker }
205*993b0882SAndroid Build Coastguard Worker
206*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
207