xref: /aosp_15_r20/external/libtextclassifier/native/annotator/cached-features.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "annotator/cached-features.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h"
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker namespace {
25*993b0882SAndroid Build Coastguard Worker 
CalculateOutputFeaturesSize(const FeatureProcessorOptions * options,int feature_vector_size)26*993b0882SAndroid Build Coastguard Worker int CalculateOutputFeaturesSize(const FeatureProcessorOptions* options,
27*993b0882SAndroid Build Coastguard Worker                                 int feature_vector_size) {
28*993b0882SAndroid Build Coastguard Worker   const bool bounds_sensitive_enabled =
29*993b0882SAndroid Build Coastguard Worker       options->bounds_sensitive_features() &&
30*993b0882SAndroid Build Coastguard Worker       options->bounds_sensitive_features()->enabled();
31*993b0882SAndroid Build Coastguard Worker 
32*993b0882SAndroid Build Coastguard Worker   int num_extracted_tokens = 0;
33*993b0882SAndroid Build Coastguard Worker   if (bounds_sensitive_enabled) {
34*993b0882SAndroid Build Coastguard Worker     const FeatureProcessorOptions_::BoundsSensitiveFeatures* config =
35*993b0882SAndroid Build Coastguard Worker         options->bounds_sensitive_features();
36*993b0882SAndroid Build Coastguard Worker     num_extracted_tokens += config->num_tokens_before();
37*993b0882SAndroid Build Coastguard Worker     num_extracted_tokens += config->num_tokens_inside_left();
38*993b0882SAndroid Build Coastguard Worker     num_extracted_tokens += config->num_tokens_inside_right();
39*993b0882SAndroid Build Coastguard Worker     num_extracted_tokens += config->num_tokens_after();
40*993b0882SAndroid Build Coastguard Worker     if (config->include_inside_bag()) {
41*993b0882SAndroid Build Coastguard Worker       ++num_extracted_tokens;
42*993b0882SAndroid Build Coastguard Worker     }
43*993b0882SAndroid Build Coastguard Worker   } else {
44*993b0882SAndroid Build Coastguard Worker     num_extracted_tokens = 2 * options->context_size() + 1;
45*993b0882SAndroid Build Coastguard Worker   }
46*993b0882SAndroid Build Coastguard Worker 
47*993b0882SAndroid Build Coastguard Worker   int output_features_size = num_extracted_tokens * feature_vector_size;
48*993b0882SAndroid Build Coastguard Worker 
49*993b0882SAndroid Build Coastguard Worker   if (bounds_sensitive_enabled &&
50*993b0882SAndroid Build Coastguard Worker       options->bounds_sensitive_features()->include_inside_length()) {
51*993b0882SAndroid Build Coastguard Worker     ++output_features_size;
52*993b0882SAndroid Build Coastguard Worker   }
53*993b0882SAndroid Build Coastguard Worker 
54*993b0882SAndroid Build Coastguard Worker   return output_features_size;
55*993b0882SAndroid Build Coastguard Worker }
56*993b0882SAndroid Build Coastguard Worker 
57*993b0882SAndroid Build Coastguard Worker }  // namespace
58*993b0882SAndroid Build Coastguard Worker 
Create(const TokenSpan & extraction_span,std::unique_ptr<std::vector<float>> features,std::unique_ptr<std::vector<float>> padding_features,const FeatureProcessorOptions * options,int feature_vector_size)59*993b0882SAndroid Build Coastguard Worker std::unique_ptr<CachedFeatures> CachedFeatures::Create(
60*993b0882SAndroid Build Coastguard Worker     const TokenSpan& extraction_span,
61*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<std::vector<float>> features,
62*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<std::vector<float>> padding_features,
63*993b0882SAndroid Build Coastguard Worker     const FeatureProcessorOptions* options, int feature_vector_size) {
64*993b0882SAndroid Build Coastguard Worker   const int min_feature_version =
65*993b0882SAndroid Build Coastguard Worker       options->bounds_sensitive_features() &&
66*993b0882SAndroid Build Coastguard Worker               options->bounds_sensitive_features()->enabled()
67*993b0882SAndroid Build Coastguard Worker           ? 2
68*993b0882SAndroid Build Coastguard Worker           : 1;
69*993b0882SAndroid Build Coastguard Worker   if (options->feature_version() < min_feature_version) {
70*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Unsupported feature version.";
71*993b0882SAndroid Build Coastguard Worker     return nullptr;
72*993b0882SAndroid Build Coastguard Worker   }
73*993b0882SAndroid Build Coastguard Worker 
74*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<CachedFeatures> cached_features(new CachedFeatures());
75*993b0882SAndroid Build Coastguard Worker   cached_features->extraction_span_ = extraction_span;
76*993b0882SAndroid Build Coastguard Worker   cached_features->features_ = std::move(features);
77*993b0882SAndroid Build Coastguard Worker   cached_features->padding_features_ = std::move(padding_features);
78*993b0882SAndroid Build Coastguard Worker   cached_features->options_ = options;
79*993b0882SAndroid Build Coastguard Worker 
80*993b0882SAndroid Build Coastguard Worker   cached_features->output_features_size_ =
81*993b0882SAndroid Build Coastguard Worker       CalculateOutputFeaturesSize(options, feature_vector_size);
82*993b0882SAndroid Build Coastguard Worker 
83*993b0882SAndroid Build Coastguard Worker   return cached_features;
84*993b0882SAndroid Build Coastguard Worker }
85*993b0882SAndroid Build Coastguard Worker 
AppendClickContextFeaturesForClick(int click_pos,std::vector<float> * output_features) const86*993b0882SAndroid Build Coastguard Worker void CachedFeatures::AppendClickContextFeaturesForClick(
87*993b0882SAndroid Build Coastguard Worker     int click_pos, std::vector<float>* output_features) const {
88*993b0882SAndroid Build Coastguard Worker   click_pos -= extraction_span_.first;
89*993b0882SAndroid Build Coastguard Worker 
90*993b0882SAndroid Build Coastguard Worker   AppendFeaturesInternal(
91*993b0882SAndroid Build Coastguard Worker       /*intended_span=*/TokenSpan(click_pos).Expand(options_->context_size(),
92*993b0882SAndroid Build Coastguard Worker                                                     options_->context_size()),
93*993b0882SAndroid Build Coastguard Worker       /*read_mask_span=*/{0, extraction_span_.Size()}, output_features);
94*993b0882SAndroid Build Coastguard Worker }
95*993b0882SAndroid Build Coastguard Worker 
AppendBoundsSensitiveFeaturesForSpan(TokenSpan selected_span,std::vector<float> * output_features) const96*993b0882SAndroid Build Coastguard Worker void CachedFeatures::AppendBoundsSensitiveFeaturesForSpan(
97*993b0882SAndroid Build Coastguard Worker     TokenSpan selected_span, std::vector<float>* output_features) const {
98*993b0882SAndroid Build Coastguard Worker   const FeatureProcessorOptions_::BoundsSensitiveFeatures* config =
99*993b0882SAndroid Build Coastguard Worker       options_->bounds_sensitive_features();
100*993b0882SAndroid Build Coastguard Worker 
101*993b0882SAndroid Build Coastguard Worker   selected_span.first -= extraction_span_.first;
102*993b0882SAndroid Build Coastguard Worker   selected_span.second -= extraction_span_.first;
103*993b0882SAndroid Build Coastguard Worker 
104*993b0882SAndroid Build Coastguard Worker   // Append the features for tokens around the left bound. Masks out tokens
105*993b0882SAndroid Build Coastguard Worker   // after the right bound, so that if num_tokens_inside_left goes past it,
106*993b0882SAndroid Build Coastguard Worker   // padding tokens will be used.
107*993b0882SAndroid Build Coastguard Worker   AppendFeaturesInternal(
108*993b0882SAndroid Build Coastguard Worker       /*intended_span=*/{selected_span.first - config->num_tokens_before(),
109*993b0882SAndroid Build Coastguard Worker                          selected_span.first +
110*993b0882SAndroid Build Coastguard Worker                              config->num_tokens_inside_left()},
111*993b0882SAndroid Build Coastguard Worker       /*read_mask_span=*/{0, selected_span.second}, output_features);
112*993b0882SAndroid Build Coastguard Worker 
113*993b0882SAndroid Build Coastguard Worker   // Append the features for tokens around the right bound. Masks out tokens
114*993b0882SAndroid Build Coastguard Worker   // before the left bound, so that if num_tokens_inside_right goes past it,
115*993b0882SAndroid Build Coastguard Worker   // padding tokens will be used.
116*993b0882SAndroid Build Coastguard Worker   AppendFeaturesInternal(
117*993b0882SAndroid Build Coastguard Worker       /*intended_span=*/{selected_span.second -
118*993b0882SAndroid Build Coastguard Worker                              config->num_tokens_inside_right(),
119*993b0882SAndroid Build Coastguard Worker                          selected_span.second + config->num_tokens_after()},
120*993b0882SAndroid Build Coastguard Worker       /*read_mask_span=*/
121*993b0882SAndroid Build Coastguard Worker       {selected_span.first, extraction_span_.Size()}, output_features);
122*993b0882SAndroid Build Coastguard Worker 
123*993b0882SAndroid Build Coastguard Worker   if (config->include_inside_bag()) {
124*993b0882SAndroid Build Coastguard Worker     AppendBagFeatures(selected_span, output_features);
125*993b0882SAndroid Build Coastguard Worker   }
126*993b0882SAndroid Build Coastguard Worker 
127*993b0882SAndroid Build Coastguard Worker   if (config->include_inside_length()) {
128*993b0882SAndroid Build Coastguard Worker     output_features->push_back(static_cast<float>(selected_span.Size()));
129*993b0882SAndroid Build Coastguard Worker   }
130*993b0882SAndroid Build Coastguard Worker }
131*993b0882SAndroid Build Coastguard Worker 
AppendFeaturesInternal(const TokenSpan & intended_span,const TokenSpan & read_mask_span,std::vector<float> * output_features) const132*993b0882SAndroid Build Coastguard Worker void CachedFeatures::AppendFeaturesInternal(
133*993b0882SAndroid Build Coastguard Worker     const TokenSpan& intended_span, const TokenSpan& read_mask_span,
134*993b0882SAndroid Build Coastguard Worker     std::vector<float>* output_features) const {
135*993b0882SAndroid Build Coastguard Worker   const TokenSpan copy_span =
136*993b0882SAndroid Build Coastguard Worker       IntersectTokenSpans(intended_span, read_mask_span);
137*993b0882SAndroid Build Coastguard Worker   for (int i = intended_span.first; i < copy_span.first; ++i) {
138*993b0882SAndroid Build Coastguard Worker     AppendPaddingFeatures(output_features);
139*993b0882SAndroid Build Coastguard Worker   }
140*993b0882SAndroid Build Coastguard Worker   output_features->insert(
141*993b0882SAndroid Build Coastguard Worker       output_features->end(),
142*993b0882SAndroid Build Coastguard Worker       features_->begin() + copy_span.first * NumFeaturesPerToken(),
143*993b0882SAndroid Build Coastguard Worker       features_->begin() + copy_span.second * NumFeaturesPerToken());
144*993b0882SAndroid Build Coastguard Worker   for (int i = copy_span.second; i < intended_span.second; ++i) {
145*993b0882SAndroid Build Coastguard Worker     AppendPaddingFeatures(output_features);
146*993b0882SAndroid Build Coastguard Worker   }
147*993b0882SAndroid Build Coastguard Worker }
148*993b0882SAndroid Build Coastguard Worker 
AppendPaddingFeatures(std::vector<float> * output_features) const149*993b0882SAndroid Build Coastguard Worker void CachedFeatures::AppendPaddingFeatures(
150*993b0882SAndroid Build Coastguard Worker     std::vector<float>* output_features) const {
151*993b0882SAndroid Build Coastguard Worker   output_features->insert(output_features->end(), padding_features_->begin(),
152*993b0882SAndroid Build Coastguard Worker                           padding_features_->end());
153*993b0882SAndroid Build Coastguard Worker }
154*993b0882SAndroid Build Coastguard Worker 
AppendBagFeatures(const TokenSpan & bag_span,std::vector<float> * output_features) const155*993b0882SAndroid Build Coastguard Worker void CachedFeatures::AppendBagFeatures(
156*993b0882SAndroid Build Coastguard Worker     const TokenSpan& bag_span, std::vector<float>* output_features) const {
157*993b0882SAndroid Build Coastguard Worker   const int offset = output_features->size();
158*993b0882SAndroid Build Coastguard Worker   output_features->resize(output_features->size() + NumFeaturesPerToken());
159*993b0882SAndroid Build Coastguard Worker   for (int i = bag_span.first; i < bag_span.second; ++i) {
160*993b0882SAndroid Build Coastguard Worker     for (int j = 0; j < NumFeaturesPerToken(); ++j) {
161*993b0882SAndroid Build Coastguard Worker       (*output_features)[offset + j] +=
162*993b0882SAndroid Build Coastguard Worker           (*features_)[i * NumFeaturesPerToken() + j] / bag_span.Size();
163*993b0882SAndroid Build Coastguard Worker     }
164*993b0882SAndroid Build Coastguard Worker   }
165*993b0882SAndroid Build Coastguard Worker }
166*993b0882SAndroid Build Coastguard Worker 
NumFeaturesPerToken() const167*993b0882SAndroid Build Coastguard Worker int CachedFeatures::NumFeaturesPerToken() const {
168*993b0882SAndroid Build Coastguard Worker   return padding_features_->size();
169*993b0882SAndroid Build Coastguard Worker }
170*993b0882SAndroid Build Coastguard Worker 
171*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
172