1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "lang_id/features/char-ngram-feature.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <mutex>
20*993b0882SAndroid Build Coastguard Worker #include <string>
21*993b0882SAndroid Build Coastguard Worker #include <utility>
22*993b0882SAndroid Build Coastguard Worker #include <vector>
23*993b0882SAndroid Build Coastguard Worker
24*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/fel/feature-types.h"
25*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/fel/task-context.h"
26*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/lite_base/logging.h"
27*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/math/hash.h"
28*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/utf8.h"
29*993b0882SAndroid Build Coastguard Worker
30*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
31*993b0882SAndroid Build Coastguard Worker namespace mobile {
32*993b0882SAndroid Build Coastguard Worker namespace lang_id {
33*993b0882SAndroid Build Coastguard Worker
Setup(TaskContext * context)34*993b0882SAndroid Build Coastguard Worker bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
35*993b0882SAndroid Build Coastguard Worker // Parameters in the feature function descriptor.
36*993b0882SAndroid Build Coastguard Worker bool include_terminators = GetBoolParameter("include_terminators", false);
37*993b0882SAndroid Build Coastguard Worker if (!include_terminators) {
38*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "No support for include_terminators=true";
39*993b0882SAndroid Build Coastguard Worker return false;
40*993b0882SAndroid Build Coastguard Worker }
41*993b0882SAndroid Build Coastguard Worker
42*993b0882SAndroid Build Coastguard Worker bool include_spaces = GetBoolParameter("include_spaces", false);
43*993b0882SAndroid Build Coastguard Worker if (include_spaces) {
44*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "No support for include_spaces=true";
45*993b0882SAndroid Build Coastguard Worker return false;
46*993b0882SAndroid Build Coastguard Worker }
47*993b0882SAndroid Build Coastguard Worker
48*993b0882SAndroid Build Coastguard Worker bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
49*993b0882SAndroid Build Coastguard Worker if (use_equal_ngram_weight) {
50*993b0882SAndroid Build Coastguard Worker SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
51*993b0882SAndroid Build Coastguard Worker return false;
52*993b0882SAndroid Build Coastguard Worker }
53*993b0882SAndroid Build Coastguard Worker
54*993b0882SAndroid Build Coastguard Worker ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
55*993b0882SAndroid Build Coastguard Worker ngram_size_ = GetIntParameter("size", 3);
56*993b0882SAndroid Build Coastguard Worker
57*993b0882SAndroid Build Coastguard Worker counts_.assign(ngram_id_dimension_, 0);
58*993b0882SAndroid Build Coastguard Worker return true;
59*993b0882SAndroid Build Coastguard Worker }
60*993b0882SAndroid Build Coastguard Worker
Init(TaskContext * context)61*993b0882SAndroid Build Coastguard Worker bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
62*993b0882SAndroid Build Coastguard Worker set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
63*993b0882SAndroid Build Coastguard Worker return true;
64*993b0882SAndroid Build Coastguard Worker }
65*993b0882SAndroid Build Coastguard Worker
ComputeNgramCounts(const LightSentence & sentence) const66*993b0882SAndroid Build Coastguard Worker int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
67*993b0882SAndroid Build Coastguard Worker const LightSentence &sentence) const {
68*993b0882SAndroid Build Coastguard Worker SAFTM_CHECK_EQ(static_cast<int>(counts_.size()), ngram_id_dimension_);
69*993b0882SAndroid Build Coastguard Worker SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0u);
70*993b0882SAndroid Build Coastguard Worker
71*993b0882SAndroid Build Coastguard Worker int total_count = 0;
72*993b0882SAndroid Build Coastguard Worker
73*993b0882SAndroid Build Coastguard Worker for (const std::string &word : sentence) {
74*993b0882SAndroid Build Coastguard Worker const char *const word_end = word.data() + word.size();
75*993b0882SAndroid Build Coastguard Worker
76*993b0882SAndroid Build Coastguard Worker // Set ngram_start at the start of the current token (word).
77*993b0882SAndroid Build Coastguard Worker const char *ngram_start = word.data();
78*993b0882SAndroid Build Coastguard Worker
79*993b0882SAndroid Build Coastguard Worker // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
80*993b0882SAndroid Build Coastguard Worker // UTF8 character contains between 1 and 4 bytes.
81*993b0882SAndroid Build Coastguard Worker const char *ngram_end = ngram_start;
82*993b0882SAndroid Build Coastguard Worker int num_utf8_chars = 0;
83*993b0882SAndroid Build Coastguard Worker do {
84*993b0882SAndroid Build Coastguard Worker ngram_end += utils::OneCharLen(ngram_end);
85*993b0882SAndroid Build Coastguard Worker num_utf8_chars++;
86*993b0882SAndroid Build Coastguard Worker } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
87*993b0882SAndroid Build Coastguard Worker
88*993b0882SAndroid Build Coastguard Worker if (num_utf8_chars < ngram_size_) {
89*993b0882SAndroid Build Coastguard Worker // Current token is so small, it does not contain a single ngram of
90*993b0882SAndroid Build Coastguard Worker // ngram_size UTF8 characters. Not much we can do in this case ...
91*993b0882SAndroid Build Coastguard Worker continue;
92*993b0882SAndroid Build Coastguard Worker }
93*993b0882SAndroid Build Coastguard Worker
94*993b0882SAndroid Build Coastguard Worker // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
95*993b0882SAndroid Build Coastguard Worker // UTF8 characters from current token.
96*993b0882SAndroid Build Coastguard Worker while (true) {
97*993b0882SAndroid Build Coastguard Worker // Compute ngram id: hash(ngram) % ngram_id_dimension
98*993b0882SAndroid Build Coastguard Worker int ngram_id = (
99*993b0882SAndroid Build Coastguard Worker utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
100*993b0882SAndroid Build Coastguard Worker % ngram_id_dimension_);
101*993b0882SAndroid Build Coastguard Worker
102*993b0882SAndroid Build Coastguard Worker // Use a reference to the actual count, such that we can both test whether
103*993b0882SAndroid Build Coastguard Worker // the count was 0 and increment it without perfoming two lookups.
104*993b0882SAndroid Build Coastguard Worker int &ref_to_count_for_ngram = counts_[ngram_id];
105*993b0882SAndroid Build Coastguard Worker if (ref_to_count_for_ngram == 0) {
106*993b0882SAndroid Build Coastguard Worker non_zero_count_indices_.push_back(ngram_id);
107*993b0882SAndroid Build Coastguard Worker }
108*993b0882SAndroid Build Coastguard Worker ref_to_count_for_ngram++;
109*993b0882SAndroid Build Coastguard Worker total_count++;
110*993b0882SAndroid Build Coastguard Worker if (ngram_end >= word_end) {
111*993b0882SAndroid Build Coastguard Worker break;
112*993b0882SAndroid Build Coastguard Worker }
113*993b0882SAndroid Build Coastguard Worker
114*993b0882SAndroid Build Coastguard Worker // Advance both ngram_start and ngram_end by one UTF8 character. This
115*993b0882SAndroid Build Coastguard Worker // way, the number of UTF8 characters between them remains constant
116*993b0882SAndroid Build Coastguard Worker // (ngram_size).
117*993b0882SAndroid Build Coastguard Worker ngram_start += utils::OneCharLen(ngram_start);
118*993b0882SAndroid Build Coastguard Worker ngram_end += utils::OneCharLen(ngram_end);
119*993b0882SAndroid Build Coastguard Worker }
120*993b0882SAndroid Build Coastguard Worker } // end of loop over tokens.
121*993b0882SAndroid Build Coastguard Worker
122*993b0882SAndroid Build Coastguard Worker return total_count;
123*993b0882SAndroid Build Coastguard Worker }
124*993b0882SAndroid Build Coastguard Worker
Evaluate(const WorkspaceSet & workspaces,const LightSentence & sentence,FeatureVector * result) const125*993b0882SAndroid Build Coastguard Worker void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
126*993b0882SAndroid Build Coastguard Worker const LightSentence &sentence,
127*993b0882SAndroid Build Coastguard Worker FeatureVector *result) const {
128*993b0882SAndroid Build Coastguard Worker // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
129*993b0882SAndroid Build Coastguard Worker // porting to Android and to avoid pulling in absl (which increases our code
130*993b0882SAndroid Build Coastguard Worker // size).
131*993b0882SAndroid Build Coastguard Worker std::lock_guard<std::mutex> mlock(state_mutex_);
132*993b0882SAndroid Build Coastguard Worker
133*993b0882SAndroid Build Coastguard Worker // Find the char ngram counts.
134*993b0882SAndroid Build Coastguard Worker int total_count = ComputeNgramCounts(sentence);
135*993b0882SAndroid Build Coastguard Worker
136*993b0882SAndroid Build Coastguard Worker // Populate the feature vector.
137*993b0882SAndroid Build Coastguard Worker const float norm = static_cast<float>(total_count);
138*993b0882SAndroid Build Coastguard Worker
139*993b0882SAndroid Build Coastguard Worker // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
140*993b0882SAndroid Build Coastguard Worker // elements) separately.
141*993b0882SAndroid Build Coastguard Worker for (int ngram_id : non_zero_count_indices_) {
142*993b0882SAndroid Build Coastguard Worker const float weight = counts_[ngram_id] / norm;
143*993b0882SAndroid Build Coastguard Worker FloatFeatureValue value(ngram_id, weight);
144*993b0882SAndroid Build Coastguard Worker result->add(feature_type(), value.discrete_value);
145*993b0882SAndroid Build Coastguard Worker
146*993b0882SAndroid Build Coastguard Worker // Clear up counts_, for the next invocation of Evaluate().
147*993b0882SAndroid Build Coastguard Worker counts_[ngram_id] = 0;
148*993b0882SAndroid Build Coastguard Worker }
149*993b0882SAndroid Build Coastguard Worker
150*993b0882SAndroid Build Coastguard Worker // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
151*993b0882SAndroid Build Coastguard Worker non_zero_count_indices_.clear();
152*993b0882SAndroid Build Coastguard Worker }
153*993b0882SAndroid Build Coastguard Worker
// Presumably registers ContinuousBagOfNgramsFunction with SAFT's static
// feature-function registry, so that feature descriptors can instantiate it
// by name - TODO confirm against the SAFTM_STATIC_REGISTRATION definition.
154*993b0882SAndroid Build Coastguard Worker SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
155*993b0882SAndroid Build Coastguard Worker
156*993b0882SAndroid Build Coastguard Worker } // namespace lang_id
157*993b0882SAndroid Build Coastguard Worker } // namespace mobile
158*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
159