xref: /aosp_15_r20/external/libtextclassifier/native/annotator/quantization_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "annotator/quantization.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <vector>
20*993b0882SAndroid Build Coastguard Worker 
21*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
22*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker using testing::ElementsAreArray;
25*993b0882SAndroid Build Coastguard Worker using testing::FloatEq;
26*993b0882SAndroid Build Coastguard Worker using testing::Matcher;
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
29*993b0882SAndroid Build Coastguard Worker namespace {
30*993b0882SAndroid Build Coastguard Worker 
ElementsAreFloat(const std::vector<float> & values)31*993b0882SAndroid Build Coastguard Worker Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
32*993b0882SAndroid Build Coastguard Worker   std::vector<Matcher<float>> matchers;
33*993b0882SAndroid Build Coastguard Worker   for (const float value : values) {
34*993b0882SAndroid Build Coastguard Worker     matchers.push_back(FloatEq(value));
35*993b0882SAndroid Build Coastguard Worker   }
36*993b0882SAndroid Build Coastguard Worker   return ElementsAreArray(matchers);
37*993b0882SAndroid Build Coastguard Worker }
38*993b0882SAndroid Build Coastguard Worker 
TEST(QuantizationTest,DequantizeAdd8bit)39*993b0882SAndroid Build Coastguard Worker TEST(QuantizationTest, DequantizeAdd8bit) {
40*993b0882SAndroid Build Coastguard Worker   std::vector<float> scales{{0.1, 9.0, -7.0}};
41*993b0882SAndroid Build Coastguard Worker   std::vector<uint8> embeddings{{/*0: */ 0x00, 0xFF, 0x09, 0x00,
42*993b0882SAndroid Build Coastguard Worker                                  /*1: */ 0xFF, 0x09, 0x00, 0xFF,
43*993b0882SAndroid Build Coastguard Worker                                  /*2: */ 0x09, 0x00, 0xFF, 0x09}};
44*993b0882SAndroid Build Coastguard Worker 
45*993b0882SAndroid Build Coastguard Worker   const int quantization_bits = 8;
46*993b0882SAndroid Build Coastguard Worker   const int bytes_per_embedding = 4;
47*993b0882SAndroid Build Coastguard Worker   const int num_sparse_features = 7;
48*993b0882SAndroid Build Coastguard Worker   {
49*993b0882SAndroid Build Coastguard Worker     const int bucket_id = 0;
50*993b0882SAndroid Build Coastguard Worker     std::vector<float> dest(4, 0.0);
51*993b0882SAndroid Build Coastguard Worker     DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
52*993b0882SAndroid Build Coastguard Worker                   num_sparse_features, quantization_bits, bucket_id,
53*993b0882SAndroid Build Coastguard Worker                   dest.data(), dest.size());
54*993b0882SAndroid Build Coastguard Worker 
55*993b0882SAndroid Build Coastguard Worker     EXPECT_THAT(dest,
56*993b0882SAndroid Build Coastguard Worker                 ElementsAreFloat(std::vector<float>{
57*993b0882SAndroid Build Coastguard Worker                     // clang-format off
58*993b0882SAndroid Build Coastguard Worker                     {1.0 / 7 * 0.1 * (0x00 - 128),
59*993b0882SAndroid Build Coastguard Worker                      1.0 / 7 * 0.1 * (0xFF - 128),
60*993b0882SAndroid Build Coastguard Worker                      1.0 / 7 * 0.1 * (0x09 - 128),
61*993b0882SAndroid Build Coastguard Worker                      1.0 / 7 * 0.1 * (0x00 - 128)}
62*993b0882SAndroid Build Coastguard Worker                     // clang-format on
63*993b0882SAndroid Build Coastguard Worker                 }));
64*993b0882SAndroid Build Coastguard Worker   }
65*993b0882SAndroid Build Coastguard Worker 
66*993b0882SAndroid Build Coastguard Worker   {
67*993b0882SAndroid Build Coastguard Worker     const int bucket_id = 1;
68*993b0882SAndroid Build Coastguard Worker     std::vector<float> dest(4, 0.0);
69*993b0882SAndroid Build Coastguard Worker     DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
70*993b0882SAndroid Build Coastguard Worker                   num_sparse_features, quantization_bits, bucket_id,
71*993b0882SAndroid Build Coastguard Worker                   dest.data(), dest.size());
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker     EXPECT_THAT(dest,
74*993b0882SAndroid Build Coastguard Worker                 ElementsAreFloat(std::vector<float>{
75*993b0882SAndroid Build Coastguard Worker                     // clang-format off
76*993b0882SAndroid Build Coastguard Worker                     {1.0 / 7 * 9.0 * (0xFF - 128),
77*993b0882SAndroid Build Coastguard Worker                      1.0 / 7 * 9.0 * (0x09 - 128),
78*993b0882SAndroid Build Coastguard Worker                      1.0 / 7 * 9.0 * (0x00 - 128),
79*993b0882SAndroid Build Coastguard Worker                      1.0 / 7 * 9.0 * (0xFF - 128)}
80*993b0882SAndroid Build Coastguard Worker                     // clang-format on
81*993b0882SAndroid Build Coastguard Worker                 }));
82*993b0882SAndroid Build Coastguard Worker   }
83*993b0882SAndroid Build Coastguard Worker }
84*993b0882SAndroid Build Coastguard Worker 
TEST(QuantizationTest, DequantizeAdd1bitZeros) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 1;
  const int bucket_id = 1;
  // At 1 bit per value, every bit of a row is one output dimension
  // (4 bytes -> 32 dimensions).
  const int num_dims = bytes_per_embedding * 8 / quantization_bits;

  // Built with fill constructors rather than std::fill, which would need
  // <algorithm> (not included by this file).
  std::vector<float> scales(num_buckets, 1.0);
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0x00);

  std::vector<float> dest(num_dims);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  // Every stored bit is 0; the expected formula subtracts a bias of 1
  // (= 2^(quantization_bits - 1)) before scaling.
  const std::vector<float> expected(
      num_dims, 1.0 / num_sparse_features * (0 - 1) * scales[bucket_id]);
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
107*993b0882SAndroid Build Coastguard Worker 
TEST(QuantizationTest, DequantizeAdd1bitOnes) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 1;
  const int bucket_id = 1;
  // At 1 bit per value, every bit of a row is one output dimension
  // (4 bytes -> 32 dimensions).
  const int num_dims = bytes_per_embedding * 8 / quantization_bits;

  std::vector<float> scales(num_buckets, 1.0);
  // 0xFF sets every bit of every row to 1.
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF);

  std::vector<float> dest(num_dims);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  // Every stored bit is 1, which the bias of 1 cancels out, so all outputs
  // are expected to be 0. Uses a fill constructor instead of std::fill,
  // which would need <algorithm> (not included by this file).
  const std::vector<float> expected(
      num_dims, 1.0 / num_sparse_features * (1 - 1) * scales[bucket_id]);
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
127*993b0882SAndroid Build Coastguard Worker 
TEST(QuantizationTest, DequantizeAdd3bit) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 3;
  const int bucket_id = 1;

  std::vector<float> scales(num_buckets, 1.0);
  scales[1] = 9.0;
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
  // Bucket 1's row occupies bytes 4..7. The 3-bit values are packed
  // least-significant bit first across consecutive bytes:
  //   byte 4 = 0b11010001: dim0 = 0b001 = 1, dim1 = 0b010 = 2, and the low
  //            two bits (0b11) of dim2;
  //   byte 5 = 0b01011000: high bit of dim2 (0) -> dim2 = 3, dim3 = 0b100 = 4,
  //            dim4 = 0b101 = 5, and the low bit (0) of dim5;
  //   byte 6 = 0b00011111: high bits of dim5 (0b11) -> dim5 = 6,
  //            dim6 = 0b111 = 7, dim7 = 0;
  //   byte 7 = 0: dims 8 and 9 are 0.
  // So dims 0..6 hold 1..7 and dims 7..9 hold 0.
  embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1;
  embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3);
  embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1;

  // Quantized value stored in each output dimension, per the layout above.
  const std::vector<int> stored_values{1, 2, 3, 4, 5, 6, 7, 0, 0, 0};

  std::vector<float> dest(stored_values.size());
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  // The expected formula subtracts a bias of 4 (= 2^(quantization_bits - 1))
  // before scaling by scales[bucket_id] / num_sparse_features.
  std::vector<float> expected;
  expected.reserve(stored_values.size());
  for (const int value : stored_values) {
    expected.push_back(1.0 / num_sparse_features * (value - 4) *
                       scales[bucket_id]);
  }
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
161*993b0882SAndroid Build Coastguard Worker 
162*993b0882SAndroid Build Coastguard Worker }  // namespace
163*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
164