xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/common/fel/feature-extractor.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/common/fel/feature-extractor.h"
18 
19 #include <string>
20 #include <vector>
21 
22 #include "lang_id/common/fel/feature-types.h"
23 #include "lang_id/common/fel/fel-parser.h"
24 #include "lang_id/common/lite_base/logging.h"
25 #include "lang_id/common/lite_strings/numbers.h"
26 
27 namespace libtextclassifier3 {
28 namespace mobile {
29 
// Out-of-class definition of the static constexpr data member kNone.  It has
// no initializer here, so the value must come from the in-class declaration
// (not visible in this file; presumably in feature-extractor.h).  Before
// C++17 this namespace-scope definition is required whenever kNone is
// odr-used, e.g., bound to a const reference.
30 constexpr FeatureValue GenericFeatureFunction::kNone;
31 
GenericFeatureExtractor()32 GenericFeatureExtractor::GenericFeatureExtractor() {}
33 
~GenericFeatureExtractor()34 GenericFeatureExtractor::~GenericFeatureExtractor() {}
35 
Parse(const std::string & source)36 bool GenericFeatureExtractor::Parse(const std::string &source) {
37   // Parse feature specification into descriptor.
38   FELParser parser;
39 
40   if (!parser.Parse(source, mutable_descriptor())) {
41     SAFTM_LOG(ERROR) << "Error parsing the FEL spec " << source;
42     return false;
43   }
44 
45   // Initialize feature extractor from descriptor.
46   return InitializeFeatureFunctions();
47 }
48 
InitializeFeatureTypes()49 bool GenericFeatureExtractor::InitializeFeatureTypes() {
50   // Register all feature types.
51   GetFeatureTypes(&feature_types_);
52   for (size_t i = 0; i < feature_types_.size(); ++i) {
53     FeatureType *ft = feature_types_[i];
54     ft->set_base(i);
55 
56     // Check for feature space overflow.
57     double domain_size = ft->GetDomainSize();
58     if (domain_size < 0) {
59       SAFTM_LOG(ERROR) << "Illegal domain size for feature " << ft->name()
60                        << ": " << domain_size;
61       return false;
62     }
63   }
64   return true;
65 }
66 
GetParameter(const std::string & name,const std::string & default_value) const67 std::string GenericFeatureFunction::GetParameter(
68     const std::string &name, const std::string &default_value) const {
69   // Find named parameter in feature descriptor.
70   for (int i = 0; i < descriptor_->parameter_size(); ++i) {
71     if (name == descriptor_->parameter(i).name()) {
72       return descriptor_->parameter(i).value();
73     }
74   }
75   return default_value;
76 }
77 
GenericFeatureFunction()78 GenericFeatureFunction::GenericFeatureFunction() {}
79 
~GenericFeatureFunction()80 GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
81 
GetIntParameter(const std::string & name,int default_value) const82 int GenericFeatureFunction::GetIntParameter(const std::string &name,
83                                             int default_value) const {
84   std::string value_str = GetParameter(name, "");
85   if (value_str.empty()) {
86     // Parameter not specified, use default value for it.
87     return default_value;
88   }
89   int value = 0;
90   if (!LiteAtoi(value_str, &value)) {
91     SAFTM_LOG(DFATAL) << "Unable to parse '" << value_str
92                       << "' as int for parameter " << name;
93     return default_value;
94   }
95   return value;
96 }
97 
GetBoolParameter(const std::string & name,bool default_value) const98 bool GenericFeatureFunction::GetBoolParameter(const std::string &name,
99                                               bool default_value) const {
100   std::string value = GetParameter(name, "");
101   if (value.empty()) return default_value;
102   if (value == "true") return true;
103   if (value == "false") return false;
104   SAFTM_LOG(DFATAL) << "Illegal value '" << value << "' for bool parameter "
105                     << name;
106   return default_value;
107 }
108 
GetFeatureTypes(std::vector<FeatureType * > * types) const109 void GenericFeatureFunction::GetFeatureTypes(
110     std::vector<FeatureType *> *types) const {
111   if (feature_type_ != nullptr) types->push_back(feature_type_);
112 }
113 
GetFeatureType() const114 FeatureType *GenericFeatureFunction::GetFeatureType() const {
115   // If a single feature type has been registered return it.
116   if (feature_type_ != nullptr) return feature_type_;
117 
118   // Get feature types for function.
119   std::vector<FeatureType *> types;
120   GetFeatureTypes(&types);
121 
122   // If there is exactly one feature type return this, else return null.
123   if (types.size() == 1) return types[0];
124   return nullptr;
125 }
126 
name() const127 std::string GenericFeatureFunction::name() const {
128   std::string output;
129   if (descriptor_->name().empty()) {
130     if (!prefix_.empty()) {
131       output.append(prefix_);
132       output.append(".");
133     }
134     ToFEL(*descriptor_, &output);
135   } else {
136     output = descriptor_->name();
137   }
138   return output;
139 }
140 
141 }  // namespace mobile
142 }  // namespace libtextclassifier3
143