xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/svdf_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite SVDF op.
16 
17 #include <stdint.h>
18 
19 #include <initializer_list>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26 
27 namespace tflite {
28 namespace {
29 
30 using ::testing::ElementsAreArray;
31 
// Shared float input for the golden tests: 10 time steps, each holding
// 2 batches x 3 input features laid out as {batch 0 (3), batch 1 (3)}.
static float svdf_input[] = {
    0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406, 0.73463392,
    0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003, 0.7567004,
    0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279,
    -0.11186574, 0.13164264, -0.05349274, -0.72674477, -0.5683046, 0.55900657,
    -0.68892461, 0.37783599, 0.18263303, -0.63690937, 0.44483393, -0.71817774,
    -0.81299269, -0.86831826, 1.43940818, -0.95760226, 1.82078898, 0.71135032,
    -1.45006323, -0.82251364, -1.69082689, -1.65087092, -1.89238167, 1.54172635,
    0.03966608, -0.24936394, -0.77526885, 2.06740379, -1.51439476, 1.43768692,
    0.11771342, -0.23761693, -0.65898693, 0.31088525, -1.55601168, -0.87661445,
    -0.89477462, 1.67204106, -0.53235275, -0.6230064, 0.29819036, 1.06939757,
};
63 
// Expected float outputs for the rank-1 model driven by svdf_input.
// 10 time steps; each step is 2 batches x 4 units, one line per batch.
static float svdf_golden_output_rank_1[] = {
    // t = 0
    0.014899, -0.0517661, -0.143725, -0.00271883,
    -0.03004015, 0.09565311, 0.1587342, 0.00784263,
    // t = 1
    0.068281, -0.162217, -0.152268, 0.00323521,
    0.01582633, 0.03858774, -0.03001583, -0.02671271,
    // t = 2
    -0.0317821, -0.0333089, 0.0609602, 0.0333759,
    -0.01432795, 0.05524484, 0.1101355, -0.02382665,
    // t = 3
    -0.00623099, -0.077701, -0.391193, -0.0136691,
    -0.02333033, 0.02293761, 0.12338032, 0.04326871,
    // t = 4
    0.201551, -0.164607, -0.179462, -0.0592739,
    0.01064911, -0.17503069, 0.07821996, -0.00224009,
    // t = 5
    0.0886511, -0.0875401, -0.269283, 0.0281379,
    -0.02282338, 0.09741908, 0.32973239, 0.12281385,
    // t = 6
    -0.201174, -0.586145, -0.628624, -0.0330412,
    0.24780814, -0.39304617, -0.22473189, 0.02589256,
    // t = 7
    -0.0839096, -0.299329, 0.108746, 0.109808,
    0.10084175, -0.06416984, 0.28936723, 0.0026358,
    // t = 8
    0.419114, -0.237824, -0.422627, 0.175115,
    -0.2314795, -0.18584411, -0.4228974, -0.12928449,
    // t = 9
    0.36726, -0.522303, -0.456502, -0.175475,
    0.17012937, -0.34447709, 0.38505614, -0.28158101,
};
95 
// Expected float outputs for the rank-2 model driven by svdf_input.
// 10 time steps; each step is 2 batches x 4 units, one line per batch.
static float svdf_golden_output_rank_2[] = {
    // t = 0
    -0.09623547, -0.10193135, 0.11083051, -0.0347917,
    0.1141196, 0.12965347, -0.12652366, 0.01007236,
    // t = 1
    -0.16396809, -0.21247184, 0.11259045, -0.04156673,
    0.10132131, -0.06143532, -0.00924693, 0.10084561,
    // t = 2
    0.01257364, 0.0506071, -0.19287863, -0.07162561,
    -0.02033747, 0.22673416, 0.15487903, 0.02525555,
    // t = 3
    -0.1411963, -0.37054959, 0.01774767, 0.05867489,
    0.09607603, -0.0141301, -0.08995658, 0.12867066,
    // t = 4
    -0.27142537, -0.16955489, 0.18521598, -0.12528358,
    0.00331409, 0.11167502, 0.02218599, -0.07309391,
    // t = 5
    0.09593632, -0.28361851, -0.0773851, 0.17199151,
    -0.00075242, 0.33691186, -0.1536046, 0.16572715,
    // t = 6
    -0.27916506, -0.27626723, 0.42615682, 0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,
    // t = 7
    -0.24418658, 0.07540751, -0.1940318, -0.08911639,
    0.00732617, 0.46737891, 0.26449674, 0.24888524,
    // t = 8
    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043, -0.28260678, 0.15666828, 1.14949894,
    // t = 9
    -0.57454878, -0.64704704, 0.73235172, -0.34616736,
    0.21120001, -0.22927976, 0.02455296, -0.35906726,
};
127 
128 // Derived class of SingleOpModel, which is used to test SVDF TFLite op.
129 class BaseSVDFOpModel : public SingleOpModel {
130  public:
BaseSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank,TensorType weights_feature_type=TensorType_FLOAT32,TensorType weights_time_type=TensorType_FLOAT32,bool asymmetric_quantize_inputs=false)131   BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
132                   int rank,
133                   TensorType weights_feature_type = TensorType_FLOAT32,
134                   TensorType weights_time_type = TensorType_FLOAT32,
135                   bool asymmetric_quantize_inputs = false)
136       : batches_(batches),
137         units_(units),
138         input_size_(input_size),
139         memory_size_(memory_size),
140         rank_(rank) {
141     input_ = AddInput(TensorType_FLOAT32);
142     weights_feature_ = AddInput(weights_feature_type);
143     weights_time_ = AddInput(weights_time_type);
144     bias_ = AddNullInput();
145     const int num_filters = units * rank;
146     activation_state_ = AddVariableInput(
147         TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}});
148     output_ = AddOutput(TensorType_FLOAT32);
149     SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
150                  CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE,
151                                    asymmetric_quantize_inputs)
152                      .Union());
153     BuildInterpreter({
154         {batches_, input_size_},              // input tensor
155         {units_ * rank, input_size_},         // weights_feature tensor
156         {units_ * rank, memory_size_},        // weights_time tensor
157         {units_},                             // bias tensor
158         {batches, memory_size * num_filters}  // activation_state tensor
159     });
160   }
161 
162   // Populates the weights_feature tensor.
SetWeightsFeature(std::initializer_list<float> f)163   void SetWeightsFeature(std::initializer_list<float> f) {
164     PopulateTensor(weights_feature_, f);
165   }
166 
167   // Populates the weights_time tensor.
SetWeightsTime(std::initializer_list<float> f)168   void SetWeightsTime(std::initializer_list<float> f) {
169     PopulateTensor(weights_time_, f);
170   }
171 
172   // Populates the input tensor.
SetInput(int offset,float * begin,float * end)173   void SetInput(int offset, float* begin, float* end) {
174     PopulateTensor(input_, offset, begin, end);
175   }
176 
177   // Extracts the output tensor from the SVDF op.
GetOutput()178   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
179 
input_size()180   int input_size() { return input_size_; }
num_units()181   int num_units() { return units_; }
num_batches()182   int num_batches() { return batches_; }
183 
184  protected:
185   int input_;
186   int weights_feature_;
187   int weights_time_;
188   int bias_;
189   int activation_state_;
190   int output_;
191 
192   int batches_;
193   int units_;
194   int input_size_;
195   int memory_size_;
196   int rank_;
197 };
198 
199 class SVDFOpModel : public BaseSVDFOpModel {
200  public:
201   using BaseSVDFOpModel::BaseSVDFOpModel;
202 };
203 
204 class HybridSVDFOpModel : public BaseSVDFOpModel {
205  public:
HybridSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank,TensorType tensor_type,bool asymmetric_quantize_inputs)206   HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
207                     int rank, TensorType tensor_type,
208                     bool asymmetric_quantize_inputs)
209       : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
210                         tensor_type, tensor_type, asymmetric_quantize_inputs) {
211     tensor_type_ = tensor_type;
212   }
213 
SetWeights(int weights_idx,const std::vector<float> & f)214   void SetWeights(int weights_idx, const std::vector<float>& f) {
215     if (tensor_type_ == TensorType_UINT8) {
216       SymmetricQuantizeAndPopulate(weights_idx, f);
217     } else {
218       SignedSymmetricQuantizeAndPopulate(weights_idx, f);
219     }
220   }
221 
SetWeightsFeature(std::initializer_list<float> f)222   void SetWeightsFeature(std::initializer_list<float> f) {
223     SetWeights(weights_feature_, f);
224   }
225 
SetWeightsTime(std::initializer_list<float> f)226   void SetWeightsTime(std::initializer_list<float> f) {
227     SetWeights(weights_time_, f);
228   }
229 
230  protected:
231   TensorType tensor_type_;
232 };
233 
234 class SVDFOpTest : public ::testing::TestWithParam<bool> {
235  protected:
VerifyGoldens(float golden_input[],float golden_output[],int golden_size,BaseSVDFOpModel * svdf,float tolerance=1e-5)236   void VerifyGoldens(float golden_input[], float golden_output[],
237                      int golden_size, BaseSVDFOpModel* svdf,
238                      float tolerance = 1e-5) {
239     const int svdf_num_batches = svdf->num_batches();
240     const int svdf_input_size = svdf->input_size();
241     const int svdf_num_units = svdf->num_units();
242     const int input_sequence_size =
243         golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
244     // Going over each input batch, setting the input tensor, invoking the SVDF
245     // op and checking the output with the expected golden values.
246     for (int i = 0; i < input_sequence_size; i++) {
247       float* batch_start =
248           golden_input + i * svdf_input_size * svdf_num_batches;
249       float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
250       svdf->SetInput(0, batch_start, batch_end);
251 
252       ASSERT_EQ(svdf->Invoke(), kTfLiteOk);
253 
254       const float* golden_start =
255           golden_output + i * svdf_num_units * svdf_num_batches;
256       const float* golden_end =
257           golden_start + svdf_num_units * svdf_num_batches;
258       std::vector<float> expected;
259       expected.insert(expected.end(), golden_start, golden_end);
260 
261       EXPECT_THAT(svdf->GetOutput(),
262                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
263     }
264   }
265 };
266 
267 INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest,
268                          ::testing::ValuesIn({false, true}));
269 
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  // Float SVDF: 2 batches x 3 inputs, 4 units, rank 1 (4 filters), 10 memory
  // slots per filter.
  constexpr int kBatches = 2;
  constexpr int kUnits = 4;
  constexpr int kInputSize = 3;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 1;
  SVDFOpModel svdf(kBatches, kUnits, kInputSize, kMemorySize, kRank);

  // weights_feature is {units * rank, input_size}: one row per filter.
  svdf.SetWeightsFeature({
      -0.31930989, -0.36118156, 0.0079667,    // filter 0
      0.37613347, 0.22197971, 0.12416199,     // filter 1
      0.27901134, 0.27557442, 0.3905206,      // filter 2
      -0.36137494, -0.06634006, -0.10640851,  // filter 3
  });

  // weights_time is {units * rank, memory_size}: one row per filter.
  svdf.SetWeightsTime({
      // filter 0
      -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
      0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
      // filter 1
      0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
      -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
      // filter 2
      -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
      0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
      // filter 3
      -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
  });

  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf);
}
293 
TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  // Float SVDF: 2 batches x 3 inputs, 4 units, rank 2 (8 filters), 10 memory
  // slots per filter.
  constexpr int kBatches = 2;
  constexpr int kUnits = 4;
  constexpr int kInputSize = 3;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 2;
  SVDFOpModel svdf(kBatches, kUnits, kInputSize, kMemorySize, kRank);

  // weights_feature is {units * rank, input_size}: one row per filter.
  svdf.SetWeightsFeature({
      -0.31930989, 0.0079667, 0.39296314,     // filter 0
      0.37613347, 0.12416199, 0.15785322,     // filter 1
      0.27901134, 0.3905206, 0.21931258,      // filter 2
      -0.36137494, -0.10640851, 0.31053296,   // filter 3
      -0.36118156, -0.0976817, -0.36916667,   // filter 4
      0.22197971, 0.15294972, 0.38031587,     // filter 5
      0.27557442, 0.39635518, -0.21580373,    // filter 6
      -0.06634006, -0.02702999, 0.27072677,   // filter 7
  });

  // weights_time is {units * rank, memory_size}: one row per filter.
  svdf.SetWeightsTime({
      // filter 0
      -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
      0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
      // filter 1
      0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
      -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
      // filter 2
      -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
      0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
      // filter 3
      -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
      // filter 4
      -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
      0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
      // filter 5
      -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
      0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
      // filter 6
      -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
      -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
      // filter 7
      0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
      0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763,
  });

  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf);
}
332 
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
  // Hybrid SVDF with uint8 weights: 2 batches x 3 inputs, 4 units, rank 1,
  // 10 memory slots. The test parameter selects asymmetric_quantize_inputs.
  constexpr int kBatches = 2;
  constexpr int kUnits = 4;
  constexpr int kInputSize = 3;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 1;
  const bool asymmetric_quantize_inputs = GetParam();
  HybridSVDFOpModel svdf(kBatches, kUnits, kInputSize, kMemorySize, kRank,
                         TensorType_UINT8, asymmetric_quantize_inputs);

  // weights_feature is {units * rank, input_size}: one row per filter.
  svdf.SetWeightsFeature({
      -0.31930989, -0.36118156, 0.0079667,    // filter 0
      0.37613347, 0.22197971, 0.12416199,     // filter 1
      0.27901134, 0.27557442, 0.3905206,      // filter 2
      -0.36137494, -0.06634006, -0.10640851,  // filter 3
  });

  // weights_time is {units * rank, memory_size}: one row per filter.
  svdf.SetWeightsTime({
      // filter 0
      -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
      0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
      // filter 1
      0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
      -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
      // filter 2
      -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
      0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
      // filter 3
      -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
  });

  // Looser than VerifyGoldens' 1e-5 default because the weights are
  // quantized.
  constexpr float kTolerance = 0.004285;
  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf, kTolerance);
}
358 
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
  // Hybrid SVDF with uint8 weights: 2 batches x 3 inputs, 4 units, rank 2,
  // 10 memory slots. The test parameter selects asymmetric_quantize_inputs.
  constexpr int kBatches = 2;
  constexpr int kUnits = 4;
  constexpr int kInputSize = 3;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 2;
  const bool asymmetric_quantize_inputs = GetParam();
  HybridSVDFOpModel svdf(kBatches, kUnits, kInputSize, kMemorySize, kRank,
                         TensorType_UINT8, asymmetric_quantize_inputs);

  // weights_feature is {units * rank, input_size}: one row per filter.
  svdf.SetWeightsFeature({
      -0.31930989, 0.0079667, 0.39296314,     // filter 0
      0.37613347, 0.12416199, 0.15785322,     // filter 1
      0.27901134, 0.3905206, 0.21931258,      // filter 2
      -0.36137494, -0.10640851, 0.31053296,   // filter 3
      -0.36118156, -0.0976817, -0.36916667,   // filter 4
      0.22197971, 0.15294972, 0.38031587,     // filter 5
      0.27557442, 0.39635518, -0.21580373,    // filter 6
      -0.06634006, -0.02702999, 0.27072677,   // filter 7
  });

  // weights_time is {units * rank, memory_size}: one row per filter.
  svdf.SetWeightsTime({
      // filter 0
      -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
      0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
      // filter 1
      0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
      -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
      // filter 2
      -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
      0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
      // filter 3
      -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
      // filter 4
      -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
      0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
      // filter 5
      -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
      0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
      // filter 6
      -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
      -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
      // filter 7
      0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
      0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763,
  });

  // Looser than VerifyGoldens' 1e-5 default because the weights are
  // quantized.
  constexpr float kTolerance = 0.007175;
  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf, kTolerance);
}
399 
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
  // Hybrid SVDF with int8 weights: 2 batches x 3 inputs, 4 units, rank 1,
  // 10 memory slots. The test parameter selects asymmetric_quantize_inputs.
  constexpr int kBatches = 2;
  constexpr int kUnits = 4;
  constexpr int kInputSize = 3;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 1;
  const bool asymmetric_quantize_inputs = GetParam();
  HybridSVDFOpModel svdf(kBatches, kUnits, kInputSize, kMemorySize, kRank,
                         TensorType_INT8, asymmetric_quantize_inputs);

  // weights_feature is {units * rank, input_size}: one row per filter.
  svdf.SetWeightsFeature({
      -0.31930989, -0.36118156, 0.0079667,    // filter 0
      0.37613347, 0.22197971, 0.12416199,     // filter 1
      0.27901134, 0.27557442, 0.3905206,      // filter 2
      -0.36137494, -0.06634006, -0.10640851,  // filter 3
  });

  // weights_time is {units * rank, memory_size}: one row per filter.
  svdf.SetWeightsTime({
      // filter 0
      -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
      0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
      // filter 1
      0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
      -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
      // filter 2
      -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
      0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
      // filter 3
      -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
  });

  // Looser than VerifyGoldens' 1e-5 default because the weights are
  // quantized.
  constexpr float kTolerance = 0.004285;
  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf, kTolerance);
}
425 
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
  // Hybrid SVDF with int8 weights: 2 batches x 3 inputs, 4 units, rank 2,
  // 10 memory slots. The test parameter selects asymmetric_quantize_inputs.
  constexpr int kBatches = 2;
  constexpr int kUnits = 4;
  constexpr int kInputSize = 3;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 2;
  const bool asymmetric_quantize_inputs = GetParam();
  HybridSVDFOpModel svdf(kBatches, kUnits, kInputSize, kMemorySize, kRank,
                         TensorType_INT8, asymmetric_quantize_inputs);

  // weights_feature is {units * rank, input_size}: one row per filter.
  svdf.SetWeightsFeature({
      -0.31930989, 0.0079667, 0.39296314,     // filter 0
      0.37613347, 0.12416199, 0.15785322,     // filter 1
      0.27901134, 0.3905206, 0.21931258,      // filter 2
      -0.36137494, -0.10640851, 0.31053296,   // filter 3
      -0.36118156, -0.0976817, -0.36916667,   // filter 4
      0.22197971, 0.15294972, 0.38031587,     // filter 5
      0.27557442, 0.39635518, -0.21580373,    // filter 6
      -0.06634006, -0.02702999, 0.27072677,   // filter 7
  });

  // weights_time is {units * rank, memory_size}: one row per filter.
  svdf.SetWeightsTime({
      // filter 0
      -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
      0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
      // filter 1
      0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
      -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
      // filter 2
      -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
      0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
      // filter 3
      -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
      // filter 4
      -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
      0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
      // filter 5
      -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
      0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
      // filter 6
      -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
      -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
      // filter 7
      0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
      0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763,
  });

  // Looser than VerifyGoldens' 1e-5 default because the weights are
  // quantized.
  constexpr float kTolerance = 0.007175;
  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf, kTolerance);
}
466 
467 // Test case for full integer quantization of SVDF.
468 class IntegerSVDFOpModel : public SingleOpModel {
469  public:
IntegerSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank)470   IntegerSVDFOpModel(int batches, int units, int input_size, int memory_size,
471                      int rank)
472       : batches_(batches),
473         units_(units),
474         input_size_(input_size),
475         memory_size_(memory_size),
476         rank_(rank) {
477     const int num_filters = units * rank;
478     input_ = AddInput({TensorType_INT8, {batches, input_size}, -1, 1});
479     weights_feature_ =
480         AddInput({TensorType_INT8, {num_filters, input_size}, -0.5, 0.5});
481     weights_time_ =
482         AddInput({TensorType_INT16, {num_filters, memory_size}, -1, 1});
483     bias_ = AddInput({TensorType_INT32, {units}, -512, 512});
484     activation_state_ = AddVariableInput(
485         {TensorType_INT16, {batches, memory_size * num_filters}, -16, 16});
486     output_ = AddOutput({TensorType_INT8, {batches, units}, -0.5, 0.5});
487     SetBuiltinOp(
488         BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
489         CreateSVDFOptions(builder_, rank, ActivationFunctionType_RELU).Union());
490     BuildInterpreter({
491         {batches, input_size},                // input tensor
492         {num_filters, input_size},            // weights_feature tensor
493         {num_filters, memory_size},           // weights_time tensor
494         {units},                              // bias tensor
495         {batches, memory_size * num_filters}  // activation_state tensor
496     });
497   }
498 
499   // Populates the weights_feature tensor.
SetWeightsFeature(const std::vector<float> & f)500   void SetWeightsFeature(const std::vector<float>& f) {
501     QuantizeAndPopulate<int8_t>(weights_feature_, f);
502   }
503 
504   // Populates the weights_time tensor.
SetWeightsTime(const std::vector<float> & f)505   void SetWeightsTime(const std::vector<float>& f) {
506     QuantizeAndPopulate<int16_t>(weights_time_, f);
507   }
508 
SetBias(const std::vector<float> & f)509   void SetBias(const std::vector<float>& f) {
510     QuantizeAndPopulate<int32_t>(bias_, f);
511   }
512 
513   // Populates the input tensor.
SetInput(const std::vector<float> & f)514   void SetInput(const std::vector<float>& f) {
515     QuantizeAndPopulate<int8_t>(input_, f);
516   }
517 
518   // Extracts the output tensor from the SVDF op.
GetOutput()519   std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
520 
521  protected:
522   int input_;
523   int weights_feature_;
524   int weights_time_;
525   int bias_;
526   int activation_state_;
527   int output_;
528 
529   int batches_;
530   int units_;
531   int input_size_;
532   int memory_size_;
533   int rank_;
534 };
535 
TEST_F(SVDFOpTest,BlackBoxTestInteger)536 TEST_F(SVDFOpTest, BlackBoxTestInteger) {
537   IntegerSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
538                           /*memory_size=*/10, /*rank=*/1);
539   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
540                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
541                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
542 
543   svdf.SetWeightsTime(
544       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
545        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
546 
547        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
548        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
549 
550        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
551        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
552 
553        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
554        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});
555 
556   svdf.SetBias({-0.0976817, 0.15294972, 0.39635518, -0.02702999});
557 
558   const std::vector<std::vector<float>> input_sequences = {
559       {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
560       {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
561        0.73463392},
562       {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
563        0.7567004},
564       {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
565       {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
566        0.73463392},
567       {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
568        0.7567004},
569       {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
570       {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
571        0.73463392},
572       {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
573        0.7567004},
574       {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
575       {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
576        0.73463392},
577       {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
578        0.7567004}};
579 
580   const std::vector<std::vector<int8_t>> expected_output = {
581       {-9, 24, 31, 1, -10, 10, -3, 0},
582       {2, 4, -44, -7, -10, 32, 52, 1},
583       {12, -17, 9, -8, 7, 16, -11, -8},
584       {-26, 29, 28, 16, -23, 26, 30, -6},
585       {-8, -25, -86, -5, -44, 59, 81, 15},
586       {62, -16, -37, 3, 27, 14, 34, -10},
587       {1, 24, -25, 23, 31, 61, 67, 11},
588       {-64, -65, -128, -25, -53, 59, 127, 20},
589       {20, -29, -20, -15, -28, 0, 8, -27},
590       {54, 61, -67, 38, 38, 64, 115, 0},
591       {-44, -75, -128, -20, -19, 93, 101, 35},
592       {-5, -56, 30, -18, -40, -9, -8, -31},
593   };
594 
595   for (int sequence_index = 0; sequence_index < 12; ++sequence_index) {
596     svdf.SetInput(input_sequences[sequence_index]);
597     ASSERT_EQ(svdf.Invoke(), kTfLiteOk);
598     const std::vector<int8_t> res = svdf.GetOutput();
599     EXPECT_THAT(res, ElementsAreArray(expected_output[sequence_index]));
600   }
601 }
602 
603 }  // namespace
604 }  // namespace tflite
605