xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/common/embedding-network.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/common/embedding-network.h"
18 
19 #include <vector>
20 
21 #include "lang_id/common/lite_base/integral-types.h"
22 #include "lang_id/common/lite_base/logging.h"
23 
24 namespace libtextclassifier3 {
25 namespace mobile {
26 namespace {
27 
CheckNoQuantization(const EmbeddingNetworkParams::Matrix & matrix)28 void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
29   SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
30                  static_cast<int>(matrix.quant_type))
31       << "Quantization not allowed here";
32 }
33 
GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix & matrix)34 int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
35   int cols = matrix.cols;
36   QuantizationType quant_type = matrix.quant_type;
37   switch (quant_type) {
38     case QuantizationType::NONE:
39       return cols * sizeof(float);
40     case QuantizationType::UINT8:
41       return cols * sizeof(uint8);
42     case QuantizationType::UINT4:
43       SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
44       return cols / 2;
45     case QuantizationType::FLOAT16:
46       return cols * sizeof(float16);
47     default:
48       SAFTM_LOG(FATAL) << "Unknown quant type: "
49                        << static_cast<int>(quant_type);
50   }
51 }
52 
53 // Computes y = weights * Relu(x) + b where Relu is optionally applied.
54 //
55 // weights and b are the weight matrix, respectively the bias vector of a neural
56 // network layer.
57 //
58 // Note: in the research literature, usually Relu (the activation function) is
59 // the last part of a neural layer.  From that perspective, this function
60 // computes the Relu part of the previous layer (if any) and next the first half
61 // (the computation of the state) for the current layer.
62 //
63 // Note: weights is expected to be the transposed version of the real weight
64 // matrix.  Hence, instead of computing a linear combination of the columns of
65 // weights, we compute a linear combination of its rows; but we are mindful that
66 // these rows are the columns of the original matrix, hence the name
67 // weights_col_i in the code.
SparseReluProductPlusBias(bool apply_relu,const EmbeddingNetworkParams::Matrix & weights,const EmbeddingNetworkParams::Matrix & b,const std::vector<float> & x,std::vector<float> * y)68 void SparseReluProductPlusBias(bool apply_relu,
69                                const EmbeddingNetworkParams::Matrix &weights,
70                                const EmbeddingNetworkParams::Matrix &b,
71                                const std::vector<float> &x,
72                                std::vector<float> *y) {
73   // Initialize y to b.  b is a column matrix (i.e., nb.cols == 1); we already
74   // CHECK-ed that the EmbeddingNetwork constructor.
75   const float *b_start = reinterpret_cast<const float *>(b.elements);
76   SAFTM_DCHECK_EQ(b.cols, 1);
77   y->assign(b_start, b_start + b.rows);
78 
79   float *const y_data = y->data();
80   const int y_size = y->size();
81   SAFTM_CHECK_EQ(weights.cols, y_size);
82   const int x_size = x.size();
83   SAFTM_CHECK_EQ(weights.rows, x_size);
84 
85   // NOTE: the code below reads x_size * y_size elements from weights; these
86   // reads are safe as long as weights.elements contains weights.rows *
87   // weights.cols elements (where the element size depends on the quantization
88   // type).  That requirement is checked by the params provider, e.g., by
89   // EmbeddingNetworkParamsFromFlatbuffer.
90 
91   // There is some code duplication between the two main cases of the switch
92   // below: the idea was to "lift" the switch outside the loops, to reduce the
93   // number of tests at runtime.
94   switch (weights.quant_type) {
95     case QuantizationType::NONE: {
96       // We compute a linear combination of the rows from |weights|, using
97       // elements of x (optionally, Relu(x)) as scaling factors (the i-th row
98       // gets multiplied by x[i] before being added with the other rows).  Note:
99       // elements of |weights| are stored in row-major order: first the elements
100       // of row #0, next the elements of row #1, etc.  In the comments below, we
101       // write "weights[i][j]" to refer to the j-th element from the i-th row of
102       // weights.
103       const float *weight_ptr =
104           reinterpret_cast<const float *>(weights.elements);
105       for (int i = 0; i < x_size; ++i) {
106         // Invariant 1: weight_ptr points to the beginning of the i-th row from
107         // weights (i.e., weights[i][0]).
108         const float scale = x[i];
109         if (!apply_relu || (scale > 0)) {
110           for (int j = 0; j < y_size; ++j, ++weight_ptr) {
111             // Invariant 2: weight_ptr points to weights[i][j].
112             y_data[j] += (*weight_ptr) * scale;
113           }
114         } else {
115           // We don't update y_data, but we still have to move weight_ptr to the
116           // next row (to satisfy Invariant 1).  We do this by adding y_size ==
117           // weights.cols() (see earlier CHECK_EQ).
118           weight_ptr += y_size;
119         }
120       }
121       break;
122     }
123     case QuantizationType::FLOAT16: {
124       // See comments for the QuantizationType::NONE case: the code is almost
125       // identical, except for float16 (instead of float) and the Float16To32
126       // conversion.  We could unify these two cases using a template, but since
127       // this is a critical loop, don't want to risk that e.g., inlining of the
128       // conversion function doesn't happen.
129       const float16 *weight_ptr =
130           reinterpret_cast<const float16 *>(weights.elements);
131       for (int i = 0; i < x_size; ++i) {
132         const float scale = x[i];
133         if (!apply_relu || (scale > 0)) {
134           for (int j = 0; j < y_size; ++j, ++weight_ptr) {
135             y_data[j] += Float16To32(*weight_ptr) * scale;
136           }
137         } else {
138           weight_ptr += y_size;
139         }
140       }
141       break;
142     }
143     default:
144       SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
145                        << static_cast<int>(weights.quant_type);
146   }
147 }
148 }  // namespace
149 
ConcatEmbeddings(const std::vector<FeatureVector> & feature_vectors,std::vector<float> * concat) const150 void EmbeddingNetwork::ConcatEmbeddings(
151     const std::vector<FeatureVector> &feature_vectors,
152     std::vector<float> *concat) const {
153   concat->resize(concat_layer_size_);
154 
155   // "es_index" stands for "embedding space index".
156   for (size_t es_index = 0; es_index < feature_vectors.size(); ++es_index) {
157     const int concat_offset = concat_offset_[es_index];
158 
159     const EmbeddingNetworkParams::Matrix &embedding_matrix =
160         embedding_matrices_[es_index];
161     const int embedding_dim = embedding_matrix.cols;
162     const int embedding_row_size_in_bytes =
163         embedding_row_size_in_bytes_[es_index];
164 
165     const FeatureVector &feature_vector = feature_vectors[es_index];
166     const int num_features = feature_vector.size();
167     for (int fi = 0; fi < num_features; ++fi) {
168       const FeatureType *feature_type = feature_vector.type(fi);
169       int feature_offset = concat_offset + feature_type->base() * embedding_dim;
170       SAFTM_CHECK_LE(feature_offset + embedding_dim,
171                      static_cast<int>(concat->size()));
172 
173       // Weighted embeddings will be added starting from this address.
174       float *concat_ptr = concat->data() + feature_offset;
175 
176       // Multiplier for each embedding weight.  Includes feature weight (for
177       // continuous features) and quantization scale (for quantized embeddings).
178       float multiplier;
179       int feature_id;
180       const FeatureValue feature_value = feature_vector.value(fi);
181       if (feature_type->is_continuous()) {
182         // Continuous features (encoded as FloatFeatureValue).
183         FloatFeatureValue float_feature_value(feature_value);
184         feature_id = float_feature_value.id;
185         multiplier = float_feature_value.weight;
186       } else {
187         // Discrete features: every present feature has implicit value 1.0.
188         feature_id = feature_value;
189         multiplier = 1.0;
190       }
191 
192       SAFTM_CHECK_GE(feature_id, 0);
193       SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);
194 
195       // Pointer to float / uint8 weights for relevant embedding.
196       const void *embedding_data =
197           (reinterpret_cast<const char *>(embedding_matrix.elements) +
198            feature_id * embedding_row_size_in_bytes);
199 
200       switch (embedding_matrix.quant_type) {
201         case QuantizationType::NONE: {
202           const float *weights =
203               reinterpret_cast<const float *>(embedding_data);
204           for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
205             *concat_ptr += *weights * multiplier;
206           }
207           break;
208         }
209         case QuantizationType::UINT8: {
210           multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
211           const uint8 *quant_weights =
212               reinterpret_cast<const uint8 *>(embedding_data);
213           for (int i = 0; i < embedding_dim;
214                ++i, ++quant_weights, ++concat_ptr) {
215             // 128 is bias for UINT8 quantization.
216             *concat_ptr +=
217                 (static_cast<int>(*quant_weights) - 128) * multiplier;
218           }
219           break;
220         }
221         case QuantizationType::UINT4: {
222           multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
223           const uint8 *quant_weights =
224               reinterpret_cast<const uint8 *>(embedding_data);
225           for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
226             const uint8 qq = *quant_weights;
227             concat_ptr[0] +=
228                 (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
229             concat_ptr[1] +=
230                 (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
231                 multiplier;
232             concat_ptr += 2;
233           }
234           break;
235         }
236         default:
237           // We already checked (in GetMatrixRowSizeInBytes) that each embedding
238           // matrix has a known quantization type.  Hence, DLOG is enough here.
239           SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
240                             << static_cast<int>(embedding_matrix.quant_type);
241           break;
242       }
243     }
244   }
245 }
246 
ComputeFinalScores(const std::vector<FeatureVector> & features,std::vector<float> * scores) const247 void EmbeddingNetwork::ComputeFinalScores(
248     const std::vector<FeatureVector> &features,
249     std::vector<float> *scores) const {
250   ComputeFinalScores(features, {}, scores);
251 }
252 
ComputeFinalScores(const std::vector<FeatureVector> & features,const std::vector<float> & extra_inputs,std::vector<float> * scores) const253 void EmbeddingNetwork::ComputeFinalScores(
254     const std::vector<FeatureVector> &features,
255     const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
256   // Construct the input layer for our feed-forward neural network (FFNN).
257   std::vector<float> input;
258   ConcatEmbeddings(features, &input);
259   if (!extra_inputs.empty()) {
260     input.reserve(input.size() + extra_inputs.size());
261     for (size_t i = 0; i < extra_inputs.size(); i++) {
262       input.push_back(extra_inputs[i]);
263     }
264   }
265 
266   // Propagate input through all layers of our FFNN.
267 
268   // Alternating storage for activations of the different layers.  We can't use
269   // a single vector because all activations of the previous layer are required
270   // when computing the activations of the next one.
271   std::vector<float> storage[2];
272   const std::vector<float> *v_in = &input;
273   const int num_layers = layer_weights_.size();
274   for (int i = 0; i < num_layers; ++i) {
275     std::vector<float> *v_out = nullptr;
276     if (i == num_layers - 1) {
277       // Final layer: write results directly into |scores|.
278       v_out = scores;
279     } else {
280       // Hidden layer: write results into the alternating storage.  The i % 2
281       // trick ensures the alternation.
282       v_out = &(storage[i % 2]);
283     }
284     const bool apply_relu = i > 0;
285     SparseReluProductPlusBias(apply_relu, layer_weights_[i], layer_bias_[i],
286                               *v_in, v_out);
287     v_in = v_out;
288   }
289 }
290 
EmbeddingNetwork(const EmbeddingNetworkParams * model)291 EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
292     : model_(model) {
293   int offset_sum = 0;
294   for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
295     concat_offset_.push_back(offset_sum);
296     EmbeddingNetworkParams::Matrix matrix = model_->GetEmbeddingMatrix(i);
297     offset_sum += matrix.cols * model_->embedding_num_features(i);
298 
299     // NOTE: each Matrix is a small struct that doesn't own the actual matrix
300     // weights.  Hence, the push_back below is fast.
301     embedding_matrices_.push_back(matrix);
302     embedding_row_size_in_bytes_.push_back(GetMatrixRowSizeInBytes(matrix));
303   }
304   concat_layer_size_ = offset_sum;
305 
306   SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
307   for (int i = 0; i < model_->hidden_size(); ++i) {
308     layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));
309 
310     EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
311     SAFTM_CHECK_EQ(1, bias.cols);
312     CheckNoQuantization(bias);
313     layer_bias_.push_back(bias);
314   }
315 
316   SAFTM_CHECK(model_->HasSoftmax());
317   layer_weights_.push_back(model_->GetSoftmaxMatrix());
318 
319   EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
320   SAFTM_CHECK_EQ(1, softmax_bias.cols);
321   CheckNoQuantization(softmax_bias);
322   layer_bias_.push_back(softmax_bias);
323 }
324 
325 }  // namespace mobile
}  // namespace libtextclassifier3
327