xref: /aosp_15_r20/external/libtextclassifier/native/utils/testing/annotator.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/testing/annotator.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/mutable.h"
20*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/reflection.h"
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
23*993b0882SAndroid Build Coastguard Worker 
FirstResult(const std::vector<ClassificationResult> & results)24*993b0882SAndroid Build Coastguard Worker std::string FirstResult(const std::vector<ClassificationResult>& results) {
25*993b0882SAndroid Build Coastguard Worker   if (results.empty()) {
26*993b0882SAndroid Build Coastguard Worker     return "<INVALID RESULTS>";
27*993b0882SAndroid Build Coastguard Worker   }
28*993b0882SAndroid Build Coastguard Worker   return results[0].collection;
29*993b0882SAndroid Build Coastguard Worker }
30*993b0882SAndroid Build Coastguard Worker 
// Reads the entire contents of the file `file_name` into a string.
//
// The file is opened in binary mode so the bytes come back verbatim. Arbitrary
// binary payloads (e.g. serialized models) therefore survive unchanged on
// platforms where text mode would translate line endings or stop at a
// control character.
//
// Returns an empty string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::binary);
  // A stream that failed to open yields an empty range here, so no explicit
  // error check is needed to get the documented empty result.
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
35*993b0882SAndroid Build Coastguard Worker 
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score,const float priority_score)36*993b0882SAndroid Build Coastguard Worker std::unique_ptr<RegexModel_::PatternT> MakePattern(
37*993b0882SAndroid Build Coastguard Worker     const std::string& collection_name, const std::string& pattern,
38*993b0882SAndroid Build Coastguard Worker     const bool enabled_for_classification, const bool enabled_for_selection,
39*993b0882SAndroid Build Coastguard Worker     const bool enabled_for_annotation, const float score,
40*993b0882SAndroid Build Coastguard Worker     const float priority_score) {
41*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
42*993b0882SAndroid Build Coastguard Worker   result->collection_name = collection_name;
43*993b0882SAndroid Build Coastguard Worker   result->pattern = pattern;
44*993b0882SAndroid Build Coastguard Worker   // We cannot directly operate with |= on the flag, so use an int here.
45*993b0882SAndroid Build Coastguard Worker   int enabled_modes = ModeFlag_NONE;
46*993b0882SAndroid Build Coastguard Worker   if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
47*993b0882SAndroid Build Coastguard Worker   if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
48*993b0882SAndroid Build Coastguard Worker   if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
49*993b0882SAndroid Build Coastguard Worker   result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
50*993b0882SAndroid Build Coastguard Worker   result->target_classification_score = score;
51*993b0882SAndroid Build Coastguard Worker   result->priority_score = priority_score;
52*993b0882SAndroid Build Coastguard Worker   return result;
53*993b0882SAndroid Build Coastguard Worker }
54*993b0882SAndroid Build Coastguard Worker 
55*993b0882SAndroid Build Coastguard Worker // Shortcut function that doesn't need to specify the priority score.
MakePattern(const std::string & collection_name,const std::string & pattern,const bool enabled_for_classification,const bool enabled_for_selection,const bool enabled_for_annotation,const float score)56*993b0882SAndroid Build Coastguard Worker std::unique_ptr<RegexModel_::PatternT> MakePattern(
57*993b0882SAndroid Build Coastguard Worker     const std::string& collection_name, const std::string& pattern,
58*993b0882SAndroid Build Coastguard Worker     const bool enabled_for_classification, const bool enabled_for_selection,
59*993b0882SAndroid Build Coastguard Worker     const bool enabled_for_annotation, const float score) {
60*993b0882SAndroid Build Coastguard Worker   return MakePattern(collection_name, pattern, enabled_for_classification,
61*993b0882SAndroid Build Coastguard Worker                      enabled_for_selection, enabled_for_annotation,
62*993b0882SAndroid Build Coastguard Worker                      /*score=*/score,
63*993b0882SAndroid Build Coastguard Worker                      /*priority_score=*/score);
64*993b0882SAndroid Build Coastguard Worker }
65*993b0882SAndroid Build Coastguard Worker 
AddTestRegexModel(ModelT * unpacked_model)66*993b0882SAndroid Build Coastguard Worker void AddTestRegexModel(ModelT* unpacked_model) {
67*993b0882SAndroid Build Coastguard Worker   // Add test regex models.
68*993b0882SAndroid Build Coastguard Worker   unpacked_model->regex_model->patterns.push_back(MakePattern(
69*993b0882SAndroid Build Coastguard Worker       "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
70*993b0882SAndroid Build Coastguard Worker       /*enabled_for_classification=*/true,
71*993b0882SAndroid Build Coastguard Worker       /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker   // Use meta data to generate custom serialized entity data.
74*993b0882SAndroid Build Coastguard Worker   MutableFlatbufferBuilder entity_data_builder(
75*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<reflection::Schema>(
76*993b0882SAndroid Build Coastguard Worker           unpacked_model->entity_data_schema.data()));
77*993b0882SAndroid Build Coastguard Worker   RegexModel_::PatternT* pattern =
78*993b0882SAndroid Build Coastguard Worker       unpacked_model->regex_model->patterns.back().get();
79*993b0882SAndroid Build Coastguard Worker 
80*993b0882SAndroid Build Coastguard Worker   {
81*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<MutableFlatbuffer> entity_data =
82*993b0882SAndroid Build Coastguard Worker         entity_data_builder.NewRoot();
83*993b0882SAndroid Build Coastguard Worker     entity_data->Set("is_alive", true);
84*993b0882SAndroid Build Coastguard Worker     pattern->serialized_entity_data = entity_data->Serialize();
85*993b0882SAndroid Build Coastguard Worker   }
86*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group.emplace_back(new CapturingGroupT);
87*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group.emplace_back(new CapturingGroupT);
88*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group.emplace_back(new CapturingGroupT);
89*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group.emplace_back(new CapturingGroupT);
90*993b0882SAndroid Build Coastguard Worker   // Group 0 is the full match, capturing groups starting at 1.
91*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[1]->entity_field_path.reset(
92*993b0882SAndroid Build Coastguard Worker       new FlatbufferFieldPathT);
93*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[1]->entity_field_path->field.emplace_back(
94*993b0882SAndroid Build Coastguard Worker       new FlatbufferFieldT);
95*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
96*993b0882SAndroid Build Coastguard Worker       "first_name";
97*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[2]->entity_field_path.reset(
98*993b0882SAndroid Build Coastguard Worker       new FlatbufferFieldPathT);
99*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[2]->entity_field_path->field.emplace_back(
100*993b0882SAndroid Build Coastguard Worker       new FlatbufferFieldT);
101*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
102*993b0882SAndroid Build Coastguard Worker       "last_name";
103*993b0882SAndroid Build Coastguard Worker   // Set `former_us_president` field if we match Obama.
104*993b0882SAndroid Build Coastguard Worker   {
105*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<MutableFlatbuffer> entity_data =
106*993b0882SAndroid Build Coastguard Worker         entity_data_builder.NewRoot();
107*993b0882SAndroid Build Coastguard Worker     entity_data->Set("former_us_president", true);
108*993b0882SAndroid Build Coastguard Worker     pattern->capturing_group[2]->serialized_entity_data =
109*993b0882SAndroid Build Coastguard Worker         entity_data->Serialize();
110*993b0882SAndroid Build Coastguard Worker   }
111*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[3]->entity_field_path.reset(
112*993b0882SAndroid Build Coastguard Worker       new FlatbufferFieldPathT);
113*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[3]->entity_field_path->field.emplace_back(
114*993b0882SAndroid Build Coastguard Worker       new FlatbufferFieldT);
115*993b0882SAndroid Build Coastguard Worker   pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
116*993b0882SAndroid Build Coastguard Worker       "age";
117*993b0882SAndroid Build Coastguard Worker }
118*993b0882SAndroid Build Coastguard Worker 
CreateEmptyModel(const std::function<void (ModelT * model)> model_update_fn)119*993b0882SAndroid Build Coastguard Worker std::string CreateEmptyModel(
120*993b0882SAndroid Build Coastguard Worker     const std::function<void(ModelT* model)> model_update_fn) {
121*993b0882SAndroid Build Coastguard Worker   ModelT model;
122*993b0882SAndroid Build Coastguard Worker   model_update_fn(&model);
123*993b0882SAndroid Build Coastguard Worker 
124*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
125*993b0882SAndroid Build Coastguard Worker   FinishModelBuffer(builder, Model::Pack(builder, &model));
126*993b0882SAndroid Build Coastguard Worker   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
127*993b0882SAndroid Build Coastguard Worker                      builder.GetSize());
128*993b0882SAndroid Build Coastguard Worker }
129*993b0882SAndroid Build Coastguard Worker 
// Create fake entity data schema meta data.
//
// Uses the low-level FlatBufferBuilder API to build a reflection schema with a
// single root table "EntityData" that has these fields:
//   first_name (string), is_alive (bool), last_name (string), age (int),
//   former_us_president (bool).
// The finished schema bytes are stored in `unpacked_model->entity_data_schema`,
// replacing its previous contents.
void AddTestEntitySchemaData(ModelT* unpacked_model) {
  // Cannot use object oriented API here as that is not available for the
  // reflection schema.
  flatbuffers::FlatBufferBuilder schema_builder;
  // Each field's `offset` is 4 + 2 * id, i.e. the field's slot in the table's
  // vtable (the first two vtable entries hold the vtable/table sizes).
  std::vector<flatbuffers::Offset<reflection::Field>> fields = {
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("first_name"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::String),
          /*id=*/0,
          /*offset=*/4),
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("is_alive"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::Bool),
          /*id=*/1,
          /*offset=*/6),
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("last_name"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::String),
          /*id=*/2,
          /*offset=*/8),
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("age"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::Int),
          /*id=*/3,
          /*offset=*/10),
      reflection::CreateField(
          schema_builder,
          /*name=*/schema_builder.CreateString("former_us_president"),
          /*type=*/
          reflection::CreateType(schema_builder,
                                 /*base_type=*/reflection::Bool),
          /*id=*/4,
          /*offset=*/12)};
  // The schema declares no enums.
  std::vector<flatbuffers::Offset<reflection::Enum>> enums;
  // CreateVectorOfSortedTables writes the tables sorted by their key field
  // (the name), which the reflection lookups expect.
  std::vector<flatbuffers::Offset<reflection::Object>> objects = {
      reflection::CreateObject(
          schema_builder,
          /*name=*/schema_builder.CreateString("EntityData"),
          /*fields=*/
          schema_builder.CreateVectorOfSortedTables(&fields))};
  // With a single object, objects[0] is "EntityData" regardless of sorting,
  // and it is used as the schema's root table.
  schema_builder.Finish(reflection::CreateSchema(
      schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
      schema_builder.CreateVectorOfSortedTables(&enums),
      /*(unused) file_ident=*/0,
      /*(unused) file_ext=*/0,
      /*root_table*/ objects[0]));

  // Copy the finished schema buffer into the model.
  unpacked_model->entity_data_schema.assign(
      schema_builder.GetBufferPointer(),
      schema_builder.GetBufferPointer() + schema_builder.GetSize());
}
194*993b0882SAndroid Build Coastguard Worker 
MakeAnnotatedSpan(CodepointSpan span,const std::string & collection,const float score,AnnotatedSpan::Source source)195*993b0882SAndroid Build Coastguard Worker AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
196*993b0882SAndroid Build Coastguard Worker                                 const std::string& collection,
197*993b0882SAndroid Build Coastguard Worker                                 const float score,
198*993b0882SAndroid Build Coastguard Worker                                 AnnotatedSpan::Source source) {
199*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan result;
200*993b0882SAndroid Build Coastguard Worker   result.span = span;
201*993b0882SAndroid Build Coastguard Worker   result.classification.push_back({collection, score});
202*993b0882SAndroid Build Coastguard Worker   result.source = source;
203*993b0882SAndroid Build Coastguard Worker   return result;
204*993b0882SAndroid Build Coastguard Worker }
205*993b0882SAndroid Build Coastguard Worker 
206*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
207