1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/features/char-ngram-feature.h"
18
19 #include <mutex>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "lang_id/common/fel/feature-types.h"
25 #include "lang_id/common/fel/task-context.h"
26 #include "lang_id/common/lite_base/logging.h"
27 #include "lang_id/common/math/hash.h"
28 #include "lang_id/common/utf8.h"
29
30 namespace libtextclassifier3 {
31 namespace mobile {
32 namespace lang_id {
33
Setup(TaskContext * context)34 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
35 // Parameters in the feature function descriptor.
36 bool include_terminators = GetBoolParameter("include_terminators", false);
37 if (!include_terminators) {
38 SAFTM_LOG(ERROR) << "No support for include_terminators=true";
39 return false;
40 }
41
42 bool include_spaces = GetBoolParameter("include_spaces", false);
43 if (include_spaces) {
44 SAFTM_LOG(ERROR) << "No support for include_spaces=true";
45 return false;
46 }
47
48 bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
49 if (use_equal_ngram_weight) {
50 SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
51 return false;
52 }
53
54 ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
55 ngram_size_ = GetIntParameter("size", 3);
56
57 counts_.assign(ngram_id_dimension_, 0);
58 return true;
59 }
60
Init(TaskContext * context)61 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
62 set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
63 return true;
64 }
65
ComputeNgramCounts(const LightSentence & sentence) const66 int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
67 const LightSentence &sentence) const {
68 SAFTM_CHECK_EQ(static_cast<int>(counts_.size()), ngram_id_dimension_);
69 SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0u);
70
71 int total_count = 0;
72
73 for (const std::string &word : sentence) {
74 const char *const word_end = word.data() + word.size();
75
76 // Set ngram_start at the start of the current token (word).
77 const char *ngram_start = word.data();
78
79 // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
80 // UTF8 character contains between 1 and 4 bytes.
81 const char *ngram_end = ngram_start;
82 int num_utf8_chars = 0;
83 do {
84 ngram_end += utils::OneCharLen(ngram_end);
85 num_utf8_chars++;
86 } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
87
88 if (num_utf8_chars < ngram_size_) {
89 // Current token is so small, it does not contain a single ngram of
90 // ngram_size UTF8 characters. Not much we can do in this case ...
91 continue;
92 }
93
94 // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
95 // UTF8 characters from current token.
96 while (true) {
97 // Compute ngram id: hash(ngram) % ngram_id_dimension
98 int ngram_id = (
99 utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
100 % ngram_id_dimension_);
101
102 // Use a reference to the actual count, such that we can both test whether
103 // the count was 0 and increment it without perfoming two lookups.
104 int &ref_to_count_for_ngram = counts_[ngram_id];
105 if (ref_to_count_for_ngram == 0) {
106 non_zero_count_indices_.push_back(ngram_id);
107 }
108 ref_to_count_for_ngram++;
109 total_count++;
110 if (ngram_end >= word_end) {
111 break;
112 }
113
114 // Advance both ngram_start and ngram_end by one UTF8 character. This
115 // way, the number of UTF8 characters between them remains constant
116 // (ngram_size).
117 ngram_start += utils::OneCharLen(ngram_start);
118 ngram_end += utils::OneCharLen(ngram_end);
119 }
120 } // end of loop over tokens.
121
122 return total_count;
123 }
124
Evaluate(const WorkspaceSet & workspaces,const LightSentence & sentence,FeatureVector * result) const125 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
126 const LightSentence &sentence,
127 FeatureVector *result) const {
128 // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
129 // porting to Android and to avoid pulling in absl (which increases our code
130 // size).
131 std::lock_guard<std::mutex> mlock(state_mutex_);
132
133 // Find the char ngram counts.
134 int total_count = ComputeNgramCounts(sentence);
135
136 // Populate the feature vector.
137 const float norm = static_cast<float>(total_count);
138
139 // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
140 // elements) separately.
141 for (int ngram_id : non_zero_count_indices_) {
142 const float weight = counts_[ngram_id] / norm;
143 FloatFeatureValue value(ngram_id, weight);
144 result->add(feature_type(), value.discrete_value);
145
146 // Clear up counts_, for the next invocation of Evaluate().
147 counts_[ngram_id] = 0;
148 }
149
150 // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
151 non_zero_count_indices_.clear();
152 }
153
// Registers ContinuousBagOfNgramsFunction through the SAFTM_STATIC_REGISTRATION
// macro (defined elsewhere). NOTE(review): presumably this is what lets feature
// descriptors refer to this function by name - confirm against the macro.
SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
155
156 } // namespace lang_id
157 } // namespace mobile
} // namespace libtextclassifier3
159