xref: /aosp_15_r20/external/libtextclassifier/native/actions/feature-processor_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/feature-processor.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
20*993b0882SAndroid Build Coastguard Worker #include "annotator/model-executor.h"
21*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h"
22*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
23*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker using ::testing::FloatEq;
29*993b0882SAndroid Build Coastguard Worker using ::testing::SizeIs;
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker // EmbeddingExecutor that always returns features based on
32*993b0882SAndroid Build Coastguard Worker // the id of the sparse features.
33*993b0882SAndroid Build Coastguard Worker class FakeEmbeddingExecutor : public EmbeddingExecutor {
34*993b0882SAndroid Build Coastguard Worker  public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const35*993b0882SAndroid Build Coastguard Worker   bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
36*993b0882SAndroid Build Coastguard Worker                     const int dest_size) const override {
37*993b0882SAndroid Build Coastguard Worker     TC3_CHECK_GE(dest_size, 4);
38*993b0882SAndroid Build Coastguard Worker     EXPECT_THAT(sparse_features, SizeIs(1));
39*993b0882SAndroid Build Coastguard Worker     dest[0] = sparse_features.data()[0];
40*993b0882SAndroid Build Coastguard Worker     dest[1] = sparse_features.data()[0];
41*993b0882SAndroid Build Coastguard Worker     dest[2] = -sparse_features.data()[0];
42*993b0882SAndroid Build Coastguard Worker     dest[3] = -sparse_features.data()[0];
43*993b0882SAndroid Build Coastguard Worker     return true;
44*993b0882SAndroid Build Coastguard Worker   }
45*993b0882SAndroid Build Coastguard Worker 
46*993b0882SAndroid Build Coastguard Worker  private:
47*993b0882SAndroid Build Coastguard Worker   std::vector<float> storage_;
48*993b0882SAndroid Build Coastguard Worker };
49*993b0882SAndroid Build Coastguard Worker 
50*993b0882SAndroid Build Coastguard Worker class ActionsFeatureProcessorTest : public ::testing::Test {
51*993b0882SAndroid Build Coastguard Worker  protected:
ActionsFeatureProcessorTest()52*993b0882SAndroid Build Coastguard Worker   ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
53*993b0882SAndroid Build Coastguard Worker 
PackFeatureProcessorOptions(ActionsTokenFeatureProcessorOptionsT * options) const54*993b0882SAndroid Build Coastguard Worker   flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
55*993b0882SAndroid Build Coastguard Worker       ActionsTokenFeatureProcessorOptionsT* options) const {
56*993b0882SAndroid Build Coastguard Worker     flatbuffers::FlatBufferBuilder builder;
57*993b0882SAndroid Build Coastguard Worker     builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
58*993b0882SAndroid Build Coastguard Worker     return builder.Release();
59*993b0882SAndroid Build Coastguard Worker   }
60*993b0882SAndroid Build Coastguard Worker 
61*993b0882SAndroid Build Coastguard Worker   FakeEmbeddingExecutor embedding_executor_;
62*993b0882SAndroid Build Coastguard Worker   UniLib unilib_;
63*993b0882SAndroid Build Coastguard Worker };
64*993b0882SAndroid Build Coastguard Worker 
// With embedding_size = 4 and no extra features enabled, appending the
// features of a single token is expected to produce four floats.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
  ActionsTokenFeatureProcessorOptionsT processor_options;
  processor_options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
  processor_options.embedding_size = 4;

  flatbuffers::DetachedBuffer packed_options =
      PackFeatureProcessorOptions(&processor_options);
  const auto* options_root =
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          packed_options.data());
  ActionsFeatureProcessor processor(options_root, &unilib_);

  std::vector<float> features;
  Token lowercase_token("aaa", 0, 3);
  const bool appended = processor.AppendTokenFeatures(
      lowercase_token, &embedding_executor_, &features);
  EXPECT_TRUE(appended);
  EXPECT_THAT(features, SizeIs(4));
}
83*993b0882SAndroid Build Coastguard Worker 
// With extract_case_feature enabled, one extra float follows the four
// embedding values; for the capitalized token "Aaa" it is expected to be 1.0.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  Token token("Aaa", 0, 3);
  std::vector<float> token_features;
  // The checks below index into token_features, so these preconditions must
  // be fatal: a non-fatal EXPECT would let a short vector be read out of
  // bounds.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &token_features));
  ASSERT_THAT(token_features, SizeIs(5));
  EXPECT_THAT(token_features[4], FloatEq(1.0));
}
104*993b0882SAndroid Build Coastguard Worker 
// Appends the features of three tokens at once. Each token contributes five
// floats (four embedding values followed by the case feature), so the case
// features sit at indices 4, 9 and 14. The expected values are 1.0 for the
// capitalized tokens "Aaa" and "Cccc" and -1.0 for the lowercase "bbb".
TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
                                     Token("Cccc", 8, 12)};
  std::vector<float> token_features;
  // The checks below index into token_features, so these preconditions must
  // be fatal: a non-fatal EXPECT would let a short vector be read out of
  // bounds.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(
      tokens, &embedding_executor_, &token_features));
  ASSERT_THAT(token_features, SizeIs(15));
  EXPECT_THAT(token_features[4], FloatEq(1.0));
  EXPECT_THAT(token_features[9], FloatEq(-1.0));
  EXPECT_THAT(token_features[14], FloatEq(1.0));
}
128*993b0882SAndroid Build Coastguard Worker 
129*993b0882SAndroid Build Coastguard Worker }  // namespace
130*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
131