1*993b0882SAndroid Build Coastguard Worker /* 2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project 3*993b0882SAndroid Build Coastguard Worker * 4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*993b0882SAndroid Build Coastguard Worker * 8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*993b0882SAndroid Build Coastguard Worker * 10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*993b0882SAndroid Build Coastguard Worker * limitations under the License. 15*993b0882SAndroid Build Coastguard Worker */ 16*993b0882SAndroid Build Coastguard Worker 17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ 18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ 19*993b0882SAndroid Build Coastguard Worker 20*993b0882SAndroid Build Coastguard Worker #include <memory> 21*993b0882SAndroid Build Coastguard Worker #include <vector> 22*993b0882SAndroid Build Coastguard Worker 23*993b0882SAndroid Build Coastguard Worker #include "annotator/model-executor.h" 24*993b0882SAndroid Build Coastguard Worker #include "annotator/model_generated.h" 25*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h" 26*993b0882SAndroid Build Coastguard Worker 27*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 28*993b0882SAndroid Build Coastguard Worker 29*993b0882SAndroid Build Coastguard Worker // Holds state for extracting features across multiple calls and reusing them. 30*993b0882SAndroid Build Coastguard Worker // Assumes that features for each Token are independent. 31*993b0882SAndroid Build Coastguard Worker class CachedFeatures { 32*993b0882SAndroid Build Coastguard Worker public: 33*993b0882SAndroid Build Coastguard Worker static std::unique_ptr<CachedFeatures> Create( 34*993b0882SAndroid Build Coastguard Worker const TokenSpan& extraction_span, 35*993b0882SAndroid Build Coastguard Worker std::unique_ptr<std::vector<float>> features, 36*993b0882SAndroid Build Coastguard Worker std::unique_ptr<std::vector<float>> padding_features, 37*993b0882SAndroid Build Coastguard Worker const FeatureProcessorOptions* options, int feature_vector_size); 38*993b0882SAndroid Build Coastguard Worker 39*993b0882SAndroid Build Coastguard Worker // Appends the click context features for the given click position to 40*993b0882SAndroid Build Coastguard Worker // 'output_features'. 41*993b0882SAndroid Build Coastguard Worker void AppendClickContextFeaturesForClick( 42*993b0882SAndroid Build Coastguard Worker int click_pos, std::vector<float>* output_features) const; 43*993b0882SAndroid Build Coastguard Worker 44*993b0882SAndroid Build Coastguard Worker // Appends the bounds-sensitive features for the given token span to 45*993b0882SAndroid Build Coastguard Worker // 'output_features'. 46*993b0882SAndroid Build Coastguard Worker void AppendBoundsSensitiveFeaturesForSpan( 47*993b0882SAndroid Build Coastguard Worker TokenSpan selected_span, std::vector<float>* output_features) const; 48*993b0882SAndroid Build Coastguard Worker 49*993b0882SAndroid Build Coastguard Worker // Returns number of features that 'AppendFeaturesForSpan' appends. OutputFeaturesSize()50*993b0882SAndroid Build Coastguard Worker int OutputFeaturesSize() const { return output_features_size_; } 51*993b0882SAndroid Build Coastguard Worker 52*993b0882SAndroid Build Coastguard Worker private: CachedFeatures()53*993b0882SAndroid Build Coastguard Worker CachedFeatures() {} 54*993b0882SAndroid Build Coastguard Worker 55*993b0882SAndroid Build Coastguard Worker // Appends token features to the output. The intended_span specifies which 56*993b0882SAndroid Build Coastguard Worker // tokens' features should be used in principle. The read_mask_span restricts 57*993b0882SAndroid Build Coastguard Worker // which tokens are actually read. For tokens outside of the read_mask_span, 58*993b0882SAndroid Build Coastguard Worker // padding tokens are used instead. 59*993b0882SAndroid Build Coastguard Worker void AppendFeaturesInternal(const TokenSpan& intended_span, 60*993b0882SAndroid Build Coastguard Worker const TokenSpan& read_mask_span, 61*993b0882SAndroid Build Coastguard Worker std::vector<float>* output_features) const; 62*993b0882SAndroid Build Coastguard Worker 63*993b0882SAndroid Build Coastguard Worker // Appends features of one padding token to the output. 64*993b0882SAndroid Build Coastguard Worker void AppendPaddingFeatures(std::vector<float>* output_features) const; 65*993b0882SAndroid Build Coastguard Worker 66*993b0882SAndroid Build Coastguard Worker // Appends the features of tokens from the given span to the output. The 67*993b0882SAndroid Build Coastguard Worker // features are averaged so that the appended features have the size 68*993b0882SAndroid Build Coastguard Worker // corresponding to one token. 69*993b0882SAndroid Build Coastguard Worker void AppendBagFeatures(const TokenSpan& bag_span, 70*993b0882SAndroid Build Coastguard Worker std::vector<float>* output_features) const; 71*993b0882SAndroid Build Coastguard Worker 72*993b0882SAndroid Build Coastguard Worker int NumFeaturesPerToken() const; 73*993b0882SAndroid Build Coastguard Worker 74*993b0882SAndroid Build Coastguard Worker TokenSpan extraction_span_; 75*993b0882SAndroid Build Coastguard Worker const FeatureProcessorOptions* options_; 76*993b0882SAndroid Build Coastguard Worker int output_features_size_; 77*993b0882SAndroid Build Coastguard Worker std::unique_ptr<std::vector<float>> features_; 78*993b0882SAndroid Build Coastguard Worker std::unique_ptr<std::vector<float>> padding_features_; 79*993b0882SAndroid Build Coastguard Worker }; 80*993b0882SAndroid Build Coastguard Worker 81*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3 82*993b0882SAndroid Build Coastguard Worker 83*993b0882SAndroid Build Coastguard Worker #endif // LIBTEXTCLASSIFIER_ANNOTATOR_CACHED_FEATURES_H_ 84