1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite SVDF op.
16
17 #include <stdint.h>
18
19 #include <initializer_list>
20 #include <vector>
21
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26
27 namespace tflite {
28 namespace {
29
30 using ::testing::ElementsAreArray;
31
// Input sequence shared by the float and hybrid black-box tests below.
// VerifyGoldens() consumes it in slices of input_size * num_batches floats.
// With the models built in this file (batches = 2, input_size = 3, input
// tensor shape {batches, input_size}) every blank-line-separated group is the
// input for one Invoke(): first row is batch 0, second row is batch 1.
static float svdf_input[] = {
    0.12609188, -0.46347019, -0.89598465,
    0.35867718, 0.36897406, 0.73463392,

    0.14278367, -1.64410412, -0.75222826,
    -0.57290924, 0.12729003, 0.7567004,

    0.49837467, 0.19278903, 0.26584083,
    0.17660543, 0.52949083, -0.77931279,

    -0.11186574, 0.13164264, -0.05349274,
    -0.72674477, -0.5683046, 0.55900657,

    -0.68892461, 0.37783599, 0.18263303,
    -0.63690937, 0.44483393, -0.71817774,

    -0.81299269, -0.86831826, 1.43940818,
    -0.95760226, 1.82078898, 0.71135032,

    -1.45006323, -0.82251364, -1.69082689,
    -1.65087092, -1.89238167, 1.54172635,

    0.03966608, -0.24936394, -0.77526885,
    2.06740379, -1.51439476, 1.43768692,

    0.11771342, -0.23761693, -0.65898693,
    0.31088525, -1.55601168, -0.87661445,

    -0.89477462, 1.67204106, -0.53235275,
    -0.6230064, 0.29819036, 1.06939757,
};
63
// Expected outputs for the rank-1 tests (float and hybrid), one group per
// Invoke() on svdf_input. VerifyGoldens() reads num_units * num_batches floats
// per step, i.e. 8 values per group for the units = 4, batches = 2 models used
// here.
// NOTE(review): rows presumably correspond to batch 0 then batch 1 -- confirm
// against the op's output layout.
static float svdf_golden_output_rank_1[] = {
    0.014899, -0.0517661, -0.143725, -0.00271883,
    -0.03004015, 0.09565311, 0.1587342, 0.00784263,

    0.068281, -0.162217, -0.152268, 0.00323521,
    0.01582633, 0.03858774, -0.03001583, -0.02671271,

    -0.0317821, -0.0333089, 0.0609602, 0.0333759,
    -0.01432795, 0.05524484, 0.1101355, -0.02382665,

    -0.00623099, -0.077701, -0.391193, -0.0136691,
    -0.02333033, 0.02293761, 0.12338032, 0.04326871,

    0.201551, -0.164607, -0.179462, -0.0592739,
    0.01064911, -0.17503069, 0.07821996, -0.00224009,

    0.0886511, -0.0875401, -0.269283, 0.0281379,
    -0.02282338, 0.09741908, 0.32973239, 0.12281385,

    -0.201174, -0.586145, -0.628624, -0.0330412,
    0.24780814, -0.39304617, -0.22473189, 0.02589256,

    -0.0839096, -0.299329, 0.108746, 0.109808,
    0.10084175, -0.06416984, 0.28936723, 0.0026358,

    0.419114, -0.237824, -0.422627, 0.175115,
    -0.2314795, -0.18584411, -0.4228974, -0.12928449,

    0.36726, -0.522303, -0.456502, -0.175475,
    0.17012937, -0.34447709, 0.38505614, -0.28158101,
};
95
// Expected outputs for the rank-2 tests (float and hybrid), one group per
// Invoke() on svdf_input. VerifyGoldens() reads num_units * num_batches floats
// per step, i.e. 8 values per group for the units = 4, batches = 2 models used
// here.
// NOTE(review): rows presumably correspond to batch 0 then batch 1 -- confirm
// against the op's output layout.
static float svdf_golden_output_rank_2[] = {
    -0.09623547, -0.10193135, 0.11083051, -0.0347917,
    0.1141196, 0.12965347, -0.12652366, 0.01007236,

    -0.16396809, -0.21247184, 0.11259045, -0.04156673,
    0.10132131, -0.06143532, -0.00924693, 0.10084561,

    0.01257364, 0.0506071, -0.19287863, -0.07162561,
    -0.02033747, 0.22673416, 0.15487903, 0.02525555,

    -0.1411963, -0.37054959, 0.01774767, 0.05867489,
    0.09607603, -0.0141301, -0.08995658, 0.12867066,

    -0.27142537, -0.16955489, 0.18521598, -0.12528358,
    0.00331409, 0.11167502, 0.02218599, -0.07309391,

    0.09593632, -0.28361851, -0.0773851, 0.17199151,
    -0.00075242, 0.33691186, -0.1536046, 0.16572715,

    -0.27916506, -0.27626723, 0.42615682, 0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,

    -0.24418658, 0.07540751, -0.1940318, -0.08911639,
    0.00732617, 0.46737891, 0.26449674, 0.24888524,

    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043, -0.28260678, 0.15666828, 1.14949894,

    -0.57454878, -0.64704704, 0.73235172, -0.34616736,
    0.21120001, -0.22927976, 0.02455296, -0.35906726,
};
127
128 // Derived class of SingleOpModel, which is used to test SVDF TFLite op.
129 class BaseSVDFOpModel : public SingleOpModel {
130 public:
BaseSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank,TensorType weights_feature_type=TensorType_FLOAT32,TensorType weights_time_type=TensorType_FLOAT32,bool asymmetric_quantize_inputs=false)131 BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
132 int rank,
133 TensorType weights_feature_type = TensorType_FLOAT32,
134 TensorType weights_time_type = TensorType_FLOAT32,
135 bool asymmetric_quantize_inputs = false)
136 : batches_(batches),
137 units_(units),
138 input_size_(input_size),
139 memory_size_(memory_size),
140 rank_(rank) {
141 input_ = AddInput(TensorType_FLOAT32);
142 weights_feature_ = AddInput(weights_feature_type);
143 weights_time_ = AddInput(weights_time_type);
144 bias_ = AddNullInput();
145 const int num_filters = units * rank;
146 activation_state_ = AddVariableInput(
147 TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}});
148 output_ = AddOutput(TensorType_FLOAT32);
149 SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
150 CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE,
151 asymmetric_quantize_inputs)
152 .Union());
153 BuildInterpreter({
154 {batches_, input_size_}, // input tensor
155 {units_ * rank, input_size_}, // weights_feature tensor
156 {units_ * rank, memory_size_}, // weights_time tensor
157 {units_}, // bias tensor
158 {batches, memory_size * num_filters} // activation_state tensor
159 });
160 }
161
162 // Populates the weights_feature tensor.
SetWeightsFeature(std::initializer_list<float> f)163 void SetWeightsFeature(std::initializer_list<float> f) {
164 PopulateTensor(weights_feature_, f);
165 }
166
167 // Populates the weights_time tensor.
SetWeightsTime(std::initializer_list<float> f)168 void SetWeightsTime(std::initializer_list<float> f) {
169 PopulateTensor(weights_time_, f);
170 }
171
172 // Populates the input tensor.
SetInput(int offset,float * begin,float * end)173 void SetInput(int offset, float* begin, float* end) {
174 PopulateTensor(input_, offset, begin, end);
175 }
176
177 // Extracts the output tensor from the SVDF op.
GetOutput()178 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
179
input_size()180 int input_size() { return input_size_; }
num_units()181 int num_units() { return units_; }
num_batches()182 int num_batches() { return batches_; }
183
184 protected:
185 int input_;
186 int weights_feature_;
187 int weights_time_;
188 int bias_;
189 int activation_state_;
190 int output_;
191
192 int batches_;
193 int units_;
194 int input_size_;
195 int memory_size_;
196 int rank_;
197 };
198
// Float SVDF test model. It adds nothing to BaseSVDFOpModel and only inherits
// its constructors, so both weight tensors use the base-class default type
// (TensorType_FLOAT32) unless the caller passes different types.
class SVDFOpModel : public BaseSVDFOpModel {
 public:
  using BaseSVDFOpModel::BaseSVDFOpModel;
};
203
204 class HybridSVDFOpModel : public BaseSVDFOpModel {
205 public:
HybridSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank,TensorType tensor_type,bool asymmetric_quantize_inputs)206 HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
207 int rank, TensorType tensor_type,
208 bool asymmetric_quantize_inputs)
209 : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
210 tensor_type, tensor_type, asymmetric_quantize_inputs) {
211 tensor_type_ = tensor_type;
212 }
213
SetWeights(int weights_idx,const std::vector<float> & f)214 void SetWeights(int weights_idx, const std::vector<float>& f) {
215 if (tensor_type_ == TensorType_UINT8) {
216 SymmetricQuantizeAndPopulate(weights_idx, f);
217 } else {
218 SignedSymmetricQuantizeAndPopulate(weights_idx, f);
219 }
220 }
221
SetWeightsFeature(std::initializer_list<float> f)222 void SetWeightsFeature(std::initializer_list<float> f) {
223 SetWeights(weights_feature_, f);
224 }
225
SetWeightsTime(std::initializer_list<float> f)226 void SetWeightsTime(std::initializer_list<float> f) {
227 SetWeights(weights_time_, f);
228 }
229
230 protected:
231 TensorType tensor_type_;
232 };
233
234 class SVDFOpTest : public ::testing::TestWithParam<bool> {
235 protected:
VerifyGoldens(float golden_input[],float golden_output[],int golden_size,BaseSVDFOpModel * svdf,float tolerance=1e-5)236 void VerifyGoldens(float golden_input[], float golden_output[],
237 int golden_size, BaseSVDFOpModel* svdf,
238 float tolerance = 1e-5) {
239 const int svdf_num_batches = svdf->num_batches();
240 const int svdf_input_size = svdf->input_size();
241 const int svdf_num_units = svdf->num_units();
242 const int input_sequence_size =
243 golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
244 // Going over each input batch, setting the input tensor, invoking the SVDF
245 // op and checking the output with the expected golden values.
246 for (int i = 0; i < input_sequence_size; i++) {
247 float* batch_start =
248 golden_input + i * svdf_input_size * svdf_num_batches;
249 float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
250 svdf->SetInput(0, batch_start, batch_end);
251
252 ASSERT_EQ(svdf->Invoke(), kTfLiteOk);
253
254 const float* golden_start =
255 golden_output + i * svdf_num_units * svdf_num_batches;
256 const float* golden_end =
257 golden_start + svdf_num_units * svdf_num_batches;
258 std::vector<float> expected;
259 expected.insert(expected.end(), golden_start, golden_end);
260
261 EXPECT_THAT(svdf->GetOutput(),
262 ElementsAreArray(ArrayFloatNear(expected, tolerance)));
263 }
264 }
265 };
266
267 INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest,
268 ::testing::ValuesIn({false, true}));
269
// Float rank-1 SVDF: runs svdf_input through the op and checks every step
// against svdf_golden_output_rank_1 with the default tolerance.
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  // 4 filters (units * rank) x 3 inputs.
  const std::initializer_list<float> weights_feature = {
      -0.31930989, -0.36118156, 0.0079667,   0.37613347,
      0.22197971,  0.12416199,  0.27901134,  0.27557442,
      0.3905206,   -0.36137494, -0.06634006, -0.10640851};

  // 4 filters x memory_size 10.
  const std::initializer_list<float> weights_time = {
      -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
      0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

      0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
      -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

      -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
      0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

      -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657};

  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/1);
  svdf.SetWeightsFeature(weights_feature);
  svdf.SetWeightsTime(weights_time);

  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf);
}
293
// Float rank-2 SVDF: runs svdf_input through the op and checks every step
// against svdf_golden_output_rank_2 with the default tolerance.
TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  // 8 filters (units * rank) x 3 inputs.
  const std::initializer_list<float> weights_feature = {
      -0.31930989, 0.0079667,   0.39296314,  0.37613347,
      0.12416199,  0.15785322,  0.27901134,  0.3905206,
      0.21931258,  -0.36137494, -0.10640851, 0.31053296,
      -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
      0.15294972,  0.38031587,  0.27557442,  0.39635518,
      -0.21580373, -0.06634006, -0.02702999, 0.27072677};

  // 8 filters x memory_size 10.
  const std::initializer_list<float> weights_time = {
      -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
      0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

      0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
      -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

      -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
      0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

      -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

      -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
      0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

      -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
      0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

      -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
      -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

      0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
      0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};

  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/2);
  svdf.SetWeightsFeature(weights_feature);
  svdf.SetWeightsTime(weights_time);

  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf);
}
332
// Hybrid rank-1 SVDF with UINT8 weights. The test parameter is passed as
// asymmetric_quantize_inputs; results are compared against the float rank-1
// goldens with a looser tolerance.
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
  // 4 filters (units * rank) x 3 inputs.
  const std::initializer_list<float> weights_feature = {
      -0.31930989, -0.36118156, 0.0079667,   0.37613347,
      0.22197971,  0.12416199,  0.27901134,  0.27557442,
      0.3905206,   -0.36137494, -0.06634006, -0.10640851};

  // 4 filters x memory_size 10.
  const std::initializer_list<float> weights_time = {
      -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
      0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

      0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
      -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

      -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
      0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

      -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657};

  HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                         /*memory_size=*/10, /*rank=*/1, TensorType_UINT8,
                         GetParam());
  svdf.SetWeightsFeature(weights_feature);
  svdf.SetWeightsTime(weights_time);

  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf,
                /*tolerance=*/0.004285);
}
358
// Hybrid rank-2 SVDF with UINT8 weights. The test parameter is passed as
// asymmetric_quantize_inputs; results are compared against the float rank-2
// goldens with a looser tolerance.
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
  // 8 filters (units * rank) x 3 inputs.
  const std::initializer_list<float> weights_feature = {
      -0.31930989, 0.0079667,   0.39296314,  0.37613347,
      0.12416199,  0.15785322,  0.27901134,  0.3905206,
      0.21931258,  -0.36137494, -0.10640851, 0.31053296,
      -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
      0.15294972,  0.38031587,  0.27557442,  0.39635518,
      -0.21580373, -0.06634006, -0.02702999, 0.27072677};

  // 8 filters x memory_size 10.
  const std::initializer_list<float> weights_time = {
      -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
      0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

      0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
      -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

      -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
      0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

      -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

      -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
      0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

      -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
      0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

      -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
      -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

      0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
      0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};

  HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                         /*memory_size=*/10, /*rank=*/2, TensorType_UINT8,
                         GetParam());
  svdf.SetWeightsFeature(weights_feature);
  svdf.SetWeightsTime(weights_time);

  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf,
                /*tolerance=*/0.007175);
}
399
// Hybrid rank-1 SVDF with INT8 weights. The test parameter is passed as
// asymmetric_quantize_inputs; results are compared against the float rank-1
// goldens with a looser tolerance.
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
  // 4 filters (units * rank) x 3 inputs.
  const std::initializer_list<float> weights_feature = {
      -0.31930989, -0.36118156, 0.0079667,   0.37613347,
      0.22197971,  0.12416199,  0.27901134,  0.27557442,
      0.3905206,   -0.36137494, -0.06634006, -0.10640851};

  // 4 filters x memory_size 10.
  const std::initializer_list<float> weights_time = {
      -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
      0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

      0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
      -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

      -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
      0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

      -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657};

  HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                         /*memory_size=*/10, /*rank=*/1, TensorType_INT8,
                         GetParam());
  svdf.SetWeightsFeature(weights_feature);
  svdf.SetWeightsTime(weights_time);

  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf,
                /*tolerance=*/0.004285);
}
425
// Hybrid rank-2 SVDF with INT8 weights. The test parameter is passed as
// asymmetric_quantize_inputs; results are compared against the float rank-2
// goldens with a looser tolerance.
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
  // 8 filters (units * rank) x 3 inputs.
  const std::initializer_list<float> weights_feature = {
      -0.31930989, 0.0079667,   0.39296314,  0.37613347,
      0.12416199,  0.15785322,  0.27901134,  0.3905206,
      0.21931258,  -0.36137494, -0.10640851, 0.31053296,
      -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
      0.15294972,  0.38031587,  0.27557442,  0.39635518,
      -0.21580373, -0.06634006, -0.02702999, 0.27072677};

  // 8 filters x memory_size 10.
  const std::initializer_list<float> weights_time = {
      -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
      0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

      0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
      -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

      -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
      0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

      -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
      -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

      -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
      0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

      -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
      0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

      -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
      -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

      0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
      0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};

  HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                         /*memory_size=*/10, /*rank=*/2, TensorType_INT8,
                         GetParam());
  svdf.SetWeightsFeature(weights_feature);
  svdf.SetWeightsTime(weights_time);

  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf,
                /*tolerance=*/0.007175);
}
466
467 // Test case for full integer quantization of SVDF.
468 class IntegerSVDFOpModel : public SingleOpModel {
469 public:
IntegerSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank)470 IntegerSVDFOpModel(int batches, int units, int input_size, int memory_size,
471 int rank)
472 : batches_(batches),
473 units_(units),
474 input_size_(input_size),
475 memory_size_(memory_size),
476 rank_(rank) {
477 const int num_filters = units * rank;
478 input_ = AddInput({TensorType_INT8, {batches, input_size}, -1, 1});
479 weights_feature_ =
480 AddInput({TensorType_INT8, {num_filters, input_size}, -0.5, 0.5});
481 weights_time_ =
482 AddInput({TensorType_INT16, {num_filters, memory_size}, -1, 1});
483 bias_ = AddInput({TensorType_INT32, {units}, -512, 512});
484 activation_state_ = AddVariableInput(
485 {TensorType_INT16, {batches, memory_size * num_filters}, -16, 16});
486 output_ = AddOutput({TensorType_INT8, {batches, units}, -0.5, 0.5});
487 SetBuiltinOp(
488 BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
489 CreateSVDFOptions(builder_, rank, ActivationFunctionType_RELU).Union());
490 BuildInterpreter({
491 {batches, input_size}, // input tensor
492 {num_filters, input_size}, // weights_feature tensor
493 {num_filters, memory_size}, // weights_time tensor
494 {units}, // bias tensor
495 {batches, memory_size * num_filters} // activation_state tensor
496 });
497 }
498
499 // Populates the weights_feature tensor.
SetWeightsFeature(const std::vector<float> & f)500 void SetWeightsFeature(const std::vector<float>& f) {
501 QuantizeAndPopulate<int8_t>(weights_feature_, f);
502 }
503
504 // Populates the weights_time tensor.
SetWeightsTime(const std::vector<float> & f)505 void SetWeightsTime(const std::vector<float>& f) {
506 QuantizeAndPopulate<int16_t>(weights_time_, f);
507 }
508
SetBias(const std::vector<float> & f)509 void SetBias(const std::vector<float>& f) {
510 QuantizeAndPopulate<int32_t>(bias_, f);
511 }
512
513 // Populates the input tensor.
SetInput(const std::vector<float> & f)514 void SetInput(const std::vector<float>& f) {
515 QuantizeAndPopulate<int8_t>(input_, f);
516 }
517
518 // Extracts the output tensor from the SVDF op.
GetOutput()519 std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
520
521 protected:
522 int input_;
523 int weights_feature_;
524 int weights_time_;
525 int bias_;
526 int activation_state_;
527 int output_;
528
529 int batches_;
530 int units_;
531 int input_size_;
532 int memory_size_;
533 int rank_;
534 };
535
// Fully integer-quantized rank-1 SVDF: feeds a repeating set of three input
// frames for twelve steps and compares the raw int8 output of every step with
// precomputed values.
TEST_F(SVDFOpTest, BlackBoxTestInteger) {
  IntegerSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                          /*memory_size=*/10, /*rank=*/1);
  // 4 filters (units * rank) x 3 inputs.
  svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                          0.22197971, 0.12416199, 0.27901134, 0.27557442,
                          0.3905206, -0.36137494, -0.06634006, -0.10640851});

  // 4 filters x memory_size 10.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});

  // One bias value per unit.
  svdf.SetBias({-0.0976817, 0.15294972, 0.39635518, -0.02702999});

  // One entry per Invoke(); each entry holds batches * input_size = 6 floats.
  const std::vector<std::vector<float>> input_sequences = {
      {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
      {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
       0.73463392},
      {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
       0.7567004},
      {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
      {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
       0.73463392},
      {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
       0.7567004},
      {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
      {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
       0.73463392},
      {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
       0.7567004},
      {0.49837467, 0.19278903, 0.26584083, 0.17660543, 0.52949083, -0.77931279},
      {0.12609188, -0.46347019, -0.89598465, 0.35867718, 0.36897406,
       0.73463392},
      {0.14278367, -1.64410412, -0.75222826, -0.57290924, 0.12729003,
       0.7567004}};

  // Expected raw int8 output for the matching entry of input_sequences;
  // batches * units = 8 values each.
  const std::vector<std::vector<int8_t>> expected_output = {
      {-9, 24, 31, 1, -10, 10, -3, 0},
      {2, 4, -44, -7, -10, 32, 52, 1},
      {12, -17, 9, -8, 7, 16, -11, -8},
      {-26, 29, 28, 16, -23, 26, 30, -6},
      {-8, -25, -86, -5, -44, 59, 81, 15},
      {62, -16, -37, 3, 27, 14, 34, -10},
      {1, 24, -25, 23, 31, 61, 67, 11},
      {-64, -65, -128, -25, -53, 59, 127, 20},
      {20, -29, -20, -15, -28, 0, 8, -27},
      {54, 61, -67, 38, 38, 64, 115, 0},
      {-44, -75, -128, -20, -19, 93, 101, 35},
      {-5, -56, 30, -18, -40, -9, -8, -31},
  };

  // The loop is driven by the data instead of a hard-coded step count, so the
  // two tables must stay the same length; fail fast if an edit breaks that.
  ASSERT_EQ(input_sequences.size(), expected_output.size());

  for (size_t sequence_index = 0; sequence_index < input_sequences.size();
       ++sequence_index) {
    svdf.SetInput(input_sequences[sequence_index]);
    ASSERT_EQ(svdf.Invoke(), kTfLiteOk);
    const std::vector<int8_t> res = svdf.GetOutput();
    EXPECT_THAT(res, ElementsAreArray(expected_output[sequence_index]));
  }
}
602
603 } // namespace
604 } // namespace tflite
605