1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/feature-processor.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
20*993b0882SAndroid Build Coastguard Worker #include "annotator/model-executor.h"
21*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h"
22*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
23*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
24*993b0882SAndroid Build Coastguard Worker
25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker
28*993b0882SAndroid Build Coastguard Worker using ::testing::FloatEq;
29*993b0882SAndroid Build Coastguard Worker using ::testing::SizeIs;
30*993b0882SAndroid Build Coastguard Worker
31*993b0882SAndroid Build Coastguard Worker // EmbeddingExecutor that always returns features based on
32*993b0882SAndroid Build Coastguard Worker // the id of the sparse features.
33*993b0882SAndroid Build Coastguard Worker class FakeEmbeddingExecutor : public EmbeddingExecutor {
34*993b0882SAndroid Build Coastguard Worker public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const35*993b0882SAndroid Build Coastguard Worker bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
36*993b0882SAndroid Build Coastguard Worker const int dest_size) const override {
37*993b0882SAndroid Build Coastguard Worker TC3_CHECK_GE(dest_size, 4);
38*993b0882SAndroid Build Coastguard Worker EXPECT_THAT(sparse_features, SizeIs(1));
39*993b0882SAndroid Build Coastguard Worker dest[0] = sparse_features.data()[0];
40*993b0882SAndroid Build Coastguard Worker dest[1] = sparse_features.data()[0];
41*993b0882SAndroid Build Coastguard Worker dest[2] = -sparse_features.data()[0];
42*993b0882SAndroid Build Coastguard Worker dest[3] = -sparse_features.data()[0];
43*993b0882SAndroid Build Coastguard Worker return true;
44*993b0882SAndroid Build Coastguard Worker }
45*993b0882SAndroid Build Coastguard Worker
46*993b0882SAndroid Build Coastguard Worker private:
47*993b0882SAndroid Build Coastguard Worker std::vector<float> storage_;
48*993b0882SAndroid Build Coastguard Worker };
49*993b0882SAndroid Build Coastguard Worker
50*993b0882SAndroid Build Coastguard Worker class ActionsFeatureProcessorTest : public ::testing::Test {
51*993b0882SAndroid Build Coastguard Worker protected:
ActionsFeatureProcessorTest()52*993b0882SAndroid Build Coastguard Worker ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
53*993b0882SAndroid Build Coastguard Worker
PackFeatureProcessorOptions(ActionsTokenFeatureProcessorOptionsT * options) const54*993b0882SAndroid Build Coastguard Worker flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
55*993b0882SAndroid Build Coastguard Worker ActionsTokenFeatureProcessorOptionsT* options) const {
56*993b0882SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder builder;
57*993b0882SAndroid Build Coastguard Worker builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
58*993b0882SAndroid Build Coastguard Worker return builder.Release();
59*993b0882SAndroid Build Coastguard Worker }
60*993b0882SAndroid Build Coastguard Worker
61*993b0882SAndroid Build Coastguard Worker FakeEmbeddingExecutor embedding_executor_;
62*993b0882SAndroid Build Coastguard Worker UniLib unilib_;
63*993b0882SAndroid Build Coastguard Worker };
64*993b0882SAndroid Build Coastguard Worker
// With an embedding size of 4 and no extra features enabled, a single token
// is expected to produce exactly 4 feature values.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
  options.embedding_size = 4;

  // Round-trip the options through a flatbuffer, since the processor consumes
  // the serialized (table) form.
  flatbuffers::DetachedBuffer packed_options =
      PackFeatureProcessorOptions(&options);
  const auto* options_root =
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          packed_options.data());
  ActionsFeatureProcessor feature_processor(options_root, &unilib_);

  Token single_token("aaa", 0, 3);
  std::vector<float> features;
  EXPECT_TRUE(feature_processor.AppendTokenFeatures(
      single_token, &embedding_executor_, &features));
  EXPECT_THAT(features, SizeIs(4));
}
83*993b0882SAndroid Build Coastguard Worker
// With the case feature enabled, one extra value is expected after the 4
// embedding values; for the capitalized token "Aaa" it is expected to be 1.0.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  Token token("Aaa", 0, 3);
  std::vector<float> token_features;
  // These are preconditions for the indexing below, so they must be fatal:
  // a non-fatal EXPECT would let the test read past the end of the vector.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &token_features));
  ASSERT_THAT(token_features, SizeIs(5));
  EXPECT_THAT(token_features[4], FloatEq(1.0));
}
104*993b0882SAndroid Build Coastguard Worker
// Appending three tokens with the case feature enabled is expected to yield
// 3 * (4 embedding values + 1 case value) = 15 features. The case value sits
// at the end of each token's group: 1.0 for the capitalized tokens "Aaa" and
// "Cccc", -1.0 for the lowercase token "bbb".
TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
                                     Token("Cccc", 8, 12)};
  std::vector<float> token_features;
  // These are preconditions for the indexing below, so they must be fatal:
  // a non-fatal EXPECT would let the test read past the end of the vector.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(
      tokens, &embedding_executor_, &token_features));
  ASSERT_THAT(token_features, SizeIs(15));
  EXPECT_THAT(token_features[4], FloatEq(1.0));
  EXPECT_THAT(token_features[9], FloatEq(-1.0));
  EXPECT_THAT(token_features[14], FloatEq(1.0));
}
128*993b0882SAndroid Build Coastguard Worker
129*993b0882SAndroid Build Coastguard Worker } // namespace
130*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
131