/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite LSTM op.
//
// TODO(alanchiao): add unit test with invalid input dimensions for this and its
// variants.

#include <stdint.h>

#include <map>
#include <tuple>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

class LSTMOpModel : public SingleOpModel {
 public:
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, const TensorType weight_type,
              bool model_has_legacy_20_inputs, bool is_layer_norm,
              bool asymmetric_quantize_inputs)
      : n_input_(n_input),
        n_output_(n_output),
        n_batch_(n_batch),
        weight_type_(weight_type) {
    input_ = AddInput({TensorType_FLOAT32, {n_batch, n_input}});

    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput({weight_type, {n_cell, n_input}});
    }
    input_to_forget_weights_ = AddInput({weight_type, {n_cell, n_input}});
    input_to_cell_weights_ = AddInput({weight_type, {n_cell, n_input}});
    input_to_output_weights_ = AddInput({weight_type, {n_cell, n_input}});

    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput({weight_type, {n_cell, n_output}});
    }
    recurrent_to_forget_weights_ = AddInput({weight_type, {n_cell, n_output}});
    recurrent_to_cell_weights_ = AddInput({weight_type, {n_cell, n_output}});
    recurrent_to_output_weights_ = AddInput({weight_type, {n_cell, n_output}});

    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput({weight_type, {n_cell}});
      }
      cell_to_forget_weights_ = AddInput({weight_type, {n_cell}});
      cell_to_output_weights_ = AddInput({weight_type, {n_cell}});
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});
    }
    forget_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});
    cell_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});
    output_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});

    if (use_projection_weights) {
      projection_weights_ = AddInput({weight_type, {n_output, n_cell}});
    } else {
      projection_weights_ = AddNullInput();
    }
    if (use_projection_bias) {
      CHECK(use_projection_weights);
      projection_bias_ = AddInput({TensorType_FLOAT32, {n_output}});
    } else {
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 state tensors (output state and cell state).
    AddVariableInput({TensorType_FLOAT32, {n_batch, n_output}});
    AddVariableInput({TensorType_FLOAT32, {n_batch, n_cell}});

    // Layer norm weights, present only in the newer 24-input model.
    if (!model_has_legacy_20_inputs) {
      if (is_layer_norm) {
        if (use_cifg) {
          input_layer_norm_coefficients_ = AddNullInput();
        } else {
          input_layer_norm_coefficients_ =
              AddInput({TensorType_FLOAT32, {n_cell}});
        }
        forget_layer_norm_coefficients_ =
            AddInput({TensorType_FLOAT32, {n_cell}});
        cell_layer_norm_coefficients_ =
            AddInput({TensorType_FLOAT32, {n_cell}});
        output_layer_norm_coefficients_ =
            AddInput({TensorType_FLOAT32, {n_cell}});
      } else {
        input_layer_norm_coefficients_ = AddNullInput();
        forget_layer_norm_coefficients_ = AddNullInput();
        cell_layer_norm_coefficients_ = AddNullInput();
        output_layer_norm_coefficients_ = AddNullInput();
      }
    }

    output_ = AddOutput({TensorType_FLOAT32, {n_batch, n_output}});

    // TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the
    // default 0.
    SetBuiltinOp(
        BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
        CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                          /*cell_clip=*/0.0f, /*proj_clip=*/0.0f,
                          LSTMKernelType_FULL, asymmetric_quantize_inputs)
            .Union());

    // Input shapes are already set up, no need to pass them again.
    BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1,
                     /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/false);
  }

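  // Note on the input layout built above: the legacy model has 20 inputs (the
  // activation input, 4 input weights, 4 recurrent weights, 3 peephole
  // weights, 4 gate biases, projection weights and bias, and the 2 state
  // tensors). The newer model appends the 4 layer-norm coefficient tensors,
  // for 24 inputs in total, which is why those are added only when
  // model_has_legacy_20_inputs is false.
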
  void SetInputToInputWeights(const std::vector<float>& f) {
    SetWeights(input_to_input_weights_, f);
  }

  void SetInputToForgetWeights(const std::vector<float>& f) {
    SetWeights(input_to_forget_weights_, f);
  }

  void SetInputToCellWeights(const std::vector<float>& f) {
    SetWeights(input_to_cell_weights_, f);
  }

  void SetInputToOutputWeights(const std::vector<float>& f) {
    SetWeights(input_to_output_weights_, f);
  }

  void SetRecurrentToInputWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_input_weights_, f);
  }

  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_forget_weights_, f);
  }

  void SetRecurrentToCellWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_cell_weights_, f);
  }

  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_output_weights_, f);
  }

  void SetCellToInputWeights(const std::vector<float>& f) {
    SetWeights(cell_to_input_weights_, f);
  }

  void SetCellToForgetWeights(const std::vector<float>& f) {
    SetWeights(cell_to_forget_weights_, f);
  }

  void SetCellToOutputWeights(const std::vector<float>& f) {
    SetWeights(cell_to_output_weights_, f);
  }

  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(input_layer_norm_coefficients_, f);
  }

  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(forget_layer_norm_coefficients_, f);
  }

  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(cell_layer_norm_coefficients_, f);
  }

  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(output_layer_norm_coefficients_, f);
  }

  void SetInputGateBias(const std::vector<float>& f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(const std::vector<float>& f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(const std::vector<float>& f) {
    PopulateTensor(cell_gate_bias_, f);
  }

  void SetOutputGateBias(const std::vector<float>& f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(const std::vector<float>& f) {
    SetWeights(projection_weights_, f);
  }

  void SetProjectionBias(const std::vector<float>& f) {
    PopulateTensor(projection_bias_, f);
  }

  void SetInput(int offset, const float* begin, const float* end) {
    SingleOpModel::PopulateTensor(input_, offset, const_cast<float*>(begin),
                                  const_cast<float*>(end));
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_batches() { return n_batch_; }

 protected:
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_layer_norm_coefficients_ = kTfLiteOptionalTensor;
  int forget_layer_norm_coefficients_ = kTfLiteOptionalTensor;
  int cell_layer_norm_coefficients_ = kTfLiteOptionalTensor;
  int output_layer_norm_coefficients_ = kTfLiteOptionalTensor;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_gate_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;

  int output_;

  int n_input_;
  int n_output_;
  int n_batch_;

 private:
  void PopulateTensor(int index, const std::vector<float>& data) {
    // Nothing to do if tensor is an optional input or if data vector is empty.
    if ((index == kTfLiteOptionalTensor) || data.empty()) return;
    SingleOpModel::PopulateTensor(index, data);
  }

  void SetWeights(int index, const std::vector<float>& data) {
    if (data.empty()) return;
    if (index == kTfLiteOptionalTensor) return;
    switch (weight_type_) {
      case TensorType_FLOAT32:
        PopulateTensor(index, data);
        break;
      case TensorType_UINT8:
        SymmetricQuantizeAndPopulate(index, data);
        break;
      case TensorType_INT8:
        SignedSymmetricQuantizeAndPopulate(index, data);
        break;
      default:
        GTEST_FAIL() << "Type not supported: " << weight_type_;
        break;
    }
  }

  const TensorType weight_type_;
};
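
// Illustrative sketch only (not an executed test): one plausible way to drive
// LSTMOpModel by hand, assuming float weights and none of the optional
// features. The literal values below are placeholders, not golden data.
//
//   LSTMOpModel m(/*n_batch=*/1, /*n_input=*/2, /*n_cell=*/4, /*n_output=*/4,
//                 /*use_cifg=*/false, /*use_peephole=*/false,
//                 /*use_projection_weights=*/false,
//                 /*use_projection_bias=*/false, TensorType_FLOAT32,
//                 /*model_has_legacy_20_inputs=*/true,
//                 /*is_layer_norm=*/false,
//                 /*asymmetric_quantize_inputs=*/false);
//   m.SetInputToForgetWeights({0.1f, 0.2f, /*...*/});  // ... and likewise for
//                                                      // every non-null weight.
//   const std::vector<float> x = {2.f, 3.f};
//   m.SetInput(/*offset=*/0, x.data(), x.data() + x.size());
//   ASSERT_EQ(m.Invoke(), kTfLiteOk);
//   std::vector<float> y = m.GetOutput();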

// Parameters:
// std::get<0>(GetParam()) => weight_type
// std::get<1>(GetParam()) => model_has_legacy_20_inputs
// std::get<2>(GetParam()) => asymmetric_quantize_inputs
class LstmOpTest
    : public ::testing::TestWithParam<std::tuple<TensorType, bool, bool>> {
 protected:
  // Weights of the LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> projection_weights_;
  std::vector<float> input_layer_norm_coefficients_;
  std::vector<float> forget_layer_norm_coefficients_;
  std::vector<float> cell_layer_norm_coefficients_;
  std::vector<float> output_layer_norm_coefficients_;

  // LSTM input is stored as a num_steps * num_batch * num_inputs vector.
  std::vector<std::vector<std::vector<float>>> lstm_input_;
  // LSTM output is stored as a num_steps * num_batch * num_outputs vector.
  std::vector<std::vector<std::vector<float>>> lstm_golden_output_;
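  // For example, lstm_input_[t][b] holds the length-num_inputs feature vector
  // for batch entry b at time step t, and lstm_golden_output_[t][b] holds the
  // expected length-num_outputs activations for the same step.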

  // Runs the LSTM on lstm_input_ and compares its output against
  // lstm_golden_output_, up to the given tolerance.
  void VerifyGoldens(LSTMOpModel* lstm, float tolerance) {
    // The delegate, if used, needs to know the scales and zero-points of
    // quantized tensors, which are computed dynamically when weights are set,
    // so weights have to be set before applying the delegate.
    SetAllWeightsAndBiases(lstm);
    lstm->ApplyDelegate();

    const int num_inputs = lstm->num_inputs();
    const int num_outputs = lstm->num_outputs();
    const int num_batches = lstm->num_batches();

    ASSERT_EQ(lstm_input_.size(), lstm_golden_output_.size());
    const int num_steps = lstm_input_.size();

    for (int i = 0; i < num_steps; ++i) {
      ASSERT_EQ(num_batches, lstm_input_[i].size());
      for (int b = 0; b < num_batches; ++b) {
        ASSERT_EQ(num_inputs, lstm_input_[i][b].size());
        const float* batch_start = lstm_input_[i][b].data();
        const float* batch_end = batch_start + num_inputs;
        lstm->SetInput(b * num_inputs, batch_start, batch_end);
      }

      ASSERT_EQ(lstm->Invoke(), kTfLiteOk);

      std::vector<float> expected;
      ASSERT_EQ(num_batches, lstm_golden_output_[i].size());
      for (int b = 0; b < num_batches; ++b) {
        ASSERT_EQ(num_outputs, lstm_golden_output_[i][b].size());
        const float* batch_start = lstm_golden_output_[i][b].data();
        const float* batch_end = batch_start + num_outputs;
        expected.insert(expected.end(), batch_start, batch_end);
      }

      EXPECT_THAT(lstm->GetOutput(),
                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
    }
  }

  // Sets all weights and biases that have been defined by the test. A test may
  // define only a subset of these vectors; only the ones that have been
  // defined are set.
  void SetAllWeightsAndBiases(LSTMOpModel* lstm) {
    lstm->SetInputToInputWeights(input_to_input_weights_);
    lstm->SetInputToCellWeights(input_to_cell_weights_);
    lstm->SetInputToForgetWeights(input_to_forget_weights_);
    lstm->SetInputToOutputWeights(input_to_output_weights_);

    lstm->SetInputGateBias(input_gate_bias_);
    lstm->SetCellBias(cell_gate_bias_);
    lstm->SetForgetGateBias(forget_gate_bias_);
    lstm->SetOutputGateBias(output_gate_bias_);

    lstm->SetRecurrentToInputWeights(recurrent_to_input_weights_);
    lstm->SetRecurrentToCellWeights(recurrent_to_cell_weights_);
    lstm->SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
    lstm->SetRecurrentToOutputWeights(recurrent_to_output_weights_);

    lstm->SetCellToInputWeights(cell_to_input_weights_);
    lstm->SetCellToForgetWeights(cell_to_forget_weights_);
    lstm->SetCellToOutputWeights(cell_to_output_weights_);

    lstm->SetProjectionWeights(projection_weights_);

    lstm->SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
    lstm->SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
    lstm->SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
    lstm->SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
  }
};
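
// The weight_type / model_has_legacy_20_inputs / asymmetric_quantize_inputs
// grid for the TEST_P cases below is supplied by an INSTANTIATE_TEST_SUITE_P
// not shown in this excerpt. A minimal sketch of such an instantiation (the
// exact parameter values used are an assumption here):
//
//   INSTANTIATE_TEST_SUITE_P(
//       LstmOpTestInstance, LstmOpTest,
//       ::testing::Combine(
//           ::testing::Values(TensorType_FLOAT32, TensorType_UINT8,
//                             TensorType_INT8),
//           ::testing::Bool(),    // model_has_legacy_20_inputs
//           ::testing::Bool()));  // asymmetric_quantize_inputs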

TEST_P(LstmOpTest, NoCifg_NoPeephole_NoProjection_NoLayerNorm) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  TensorType weight_type;
  bool model_has_legacy_20_inputs;
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, model_has_legacy_20_inputs,
           asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, -0.34550029,
                             0.04266912, -0.15680569, -0.34856534, 0.43890524};
  input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
                            -0.20583314, 0.44344562, 0.22077113, -0.29909778};
  input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
                              -0.31343272, -0.40032279, 0.44781327,
                              0.01387155, -0.35593212};
  input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, 0.40525138,
                              0.44272184, 0.03897077, -0.1556896, 0.19487578};
  input_gate_bias_ = {0., 0., 0., 0.};
  cell_gate_bias_ = {0., 0., 0., 0.};
  forget_gate_bias_ = {1., 1., 1., 1.};
  output_gate_bias_ = {0., 0., 0., 0.};

  recurrent_to_input_weights_ = {
      -0.0063535, -0.2042388, 0.31454784, -0.35746509,
      0.28902304, 0.08183324, -0.16555229, 0.02286911,
      -0.13566875, 0.03034258, 0.48091322, -0.12528998,
      0.24077177, -0.51332325, -0.33502164, 0.10629296};

  recurrent_to_cell_weights_ = {
      -0.3407414, 0.24443203, -0.2078532, 0.26320225,
      0.05695659, -0.00123841, -0.4744786, -0.35869038,
      -0.06418842, -0.13502428, -0.501764, 0.22830659,
      -0.46367589, 0.26016325, -0.03894562, -0.16368064};

  recurrent_to_forget_weights_ = {
      -0.48684245, -0.06655136, 0.42224967, 0.2112639,
      0.27654213, 0.20864892, -0.07646349, 0.45877004,
      0.00141793, -0.14609534, 0.36447752, 0.09196436,
      0.28053468, 0.01560611, -0.20127171, -0.01140004};

  recurrent_to_output_weights_ = {
      0.43385774, -0.17194885, 0.2718237, 0.09215671,
      0.24107647, -0.39835793, 0.18212086, 0.01301402,
      0.48572797, -0.50656658, 0.20047462, -0.20607421,
      -0.51818722, -0.15390486, 0.0468148, 0.39922136};

  // num_steps * num_batch * num_inputs
  lstm_input_ = {{{2., 3.}}, {{3., 4.}}, {{1., 1.}}};
  // num_steps * num_batch * num_outputs
  lstm_golden_output_ = {{{-0.02973187, 0.1229473, 0.20885126, -0.15358765}},
                         {{-0.03716109, 0.12507336, 0.41193449, -0.20860538}},
                         {{-0.15053082, 0.09120187, 0.24278517, -0.12222792}}};

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false, weight_type,
                   model_has_legacy_20_inputs,
                   /*is_layer_norm=*/false, asymmetric_quantize_inputs);

  static const auto* tolerance_per_type =
      new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
                                      {TensorType_UINT8, 0.0157651f},
                                      {TensorType_INT8, 0.0157651f}};
  VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
}

TEST_P(LstmOpTest, Cifg_Peephole_NoProjection_NoLayerNorm) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  TensorType weight_type;
  bool model_has_legacy_20_inputs;
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, model_has_legacy_20_inputs,
           asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, 0.05100781,
                            0.04717243, 0.48944736, -0.38535351, -0.17212132};

  input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, -0.3633365,
                              -0.22755712, 0.28253698, 0.24407166, 0.33826375};

  input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
                              -0.09426838, -0.44257352, 0.54939759,
                              0.01533556, 0.42751634};
  cell_gate_bias_ = {0., 0., 0., 0.};
  forget_gate_bias_ = {1., 1., 1., 1.};
  output_gate_bias_ = {0., 0., 0., 0.};

  recurrent_to_cell_weights_ = {
      0.54066205, -0.32668582, -0.43562764, -0.56094903,
      0.42957711, 0.01841056, -0.32764608, -0.33027974,
      -0.10826075, 0.20675004, 0.19069612, -0.03026325,
      -0.54532051, 0.33003211, 0.44901288, 0.21193194};

  recurrent_to_forget_weights_ = {
      -0.13832897, -0.0515101, -0.2359007, -0.16661474,
      -0.14340827, 0.36986142, 0.23414481, 0.55899,
      0.10798943, -0.41174671, 0.17751795, -0.34484994,
      -0.35874045, -0.11352962, 0.27268326, 0.54058349};

  recurrent_to_output_weights_ = {
      0.41613156, 0.42610586, -0.16495961, -0.5663873,
      0.30579174, -0.05115908, -0.33941799, 0.23364776,
      0.11178309, 0.09481031, -0.26424935, 0.46261835,
      0.50248802, 0.26114327, -0.43736315, 0.33149987};

  cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, 0.31544167};
  cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, -0.77109635};

  lstm_input_ = {{{2., 3.}}, {{3., 4.}}, {{1., 1.}}};
  lstm_golden_output_ = {{{-0.36444446, -0.00352185, 0.12886585, -0.05163646}},
                         {{-0.42312205, -0.01218222, 0.24201041, -0.08124574}},
                         {{-0.358325, -0.04621704, 0.21641694, -0.06471302}}};

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/true, /*use_peephole=*/true,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false, weight_type,
                   model_has_legacy_20_inputs, /*is_layer_norm=*/false,
                   asymmetric_quantize_inputs);

  static const auto* tolerance_per_type =
      new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
                                      {TensorType_UINT8, 0.03573f},
                                      {TensorType_INT8, 0.03573f}};
  VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
}

TEST_P(LstmOpTest, NoCifg_Peephole_Projection_NoLayerNorm) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 20;
  const int n_output = 16;

  TensorType weight_type;
  bool model_has_legacy_20_inputs;
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, model_has_legacy_20_inputs,
           asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_input_weights_ = {
      0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
      0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
      -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
      -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
      -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
      -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
      -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
      0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
      0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
      0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
      -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
      0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
      -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
      -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
      -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
      0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
      -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
      -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
      -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
      -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};

  input_to_forget_weights_ = {
      -0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
      -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
      -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
      0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
      0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
      -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
      -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
      0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
      0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
      0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
      0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
      -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
      0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
      -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
      -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
      0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
      0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
      0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
      -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
      0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};

  input_to_cell_weights_ = {
      -0.04580283, -0.09549462, -0.032418985, -0.06454633, -0.043528453,
      0.043018587, -0.049152344, -0.12418144, -0.078985475, -0.07596889,
      0.019484362, -0.11434962, -0.0074034138, -0.06314844, -0.092981495,
      0.0062155537, -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
      0.10665918, -0.032036792, -0.08505916, -0.10843358, -0.13002433,
      -0.036816437, -0.02130134, -0.016518239, 0.0047691227, -0.0025825808,
      0.066017866, 0.029991534, -0.10652836, -0.1037554, -0.13056071,
      -0.03266643, -0.033702414, -0.006473424, -0.04611692, 0.014419339,
      -0.025174323, 0.0396852, 0.081777506, 0.06157468, 0.10210095,
      -0.009658194, 0.046511717, 0.03603906, 0.0069369148, 0.015960095,
      -0.06507666, 0.09551598, 0.053568836, 0.06408714, 0.12835667,
      -0.008714329, -0.20211966, -0.12093674, 0.029450472, 0.2849013,
      -0.029227901, 0.1164364, -0.08560263, 0.09941786, -0.036999565,
      -0.028842626, -0.0033637602, -0.017012902, -0.09720865, -0.11193351,
      -0.029155117, -0.017936034, -0.009768936, -0.04223324, -0.036159635,
      0.06505112, -0.021742892, -0.023377212, -0.07221364, -0.06430552,
      0.05453865, 0.091149814, 0.06387331, 0.007518393, 0.055960953,
      0.069779344, 0.046411168, 0.10509911, 0.07463894, 0.0075130584,
      0.012850982, 0.04555431, 0.056955688, 0.06555285, 0.050801456,
      -0.009862683, 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};

  input_to_output_weights_ = {
      -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
      -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
      0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
      -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
      -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
      0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
      -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
      -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
      -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
      -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
      0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
      0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
      0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
      -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
      0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
      0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
      -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
      0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
      -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
      -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};

  input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
                      0.053110216, -0.06928846, -0.13942584, -0.11816189,
                      0.19483899, 0.03652339, -0.10250295, 0.036714908,
                      -0.18426876, 0.036065217, 0.21810818, 0.02383196,
                      -0.043370757, 0.08690144, -0.04444982, 0.00030581196};

  forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
                       0.11098921, 0.15378423, 0.09263801, 0.09790885,
                       0.09508917, 0.061199076, 0.07665568, -0.015443159,
                       -0.03499149, 0.046190713, 0.08895977, 0.10899629,
                       0.40694186, 0.06030037, 0.012413437, -0.06108739};

  cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
                     -0.1483596, -0.10639995, -0.091433935, 0.058573797,
                     -0.06809782, -0.07889636, -0.043246906, -0.09829136,
                     -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
                     0.016178843, 0.1749513, 0.13975595, 0.92058027};

  output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
                       0.027195795, 0.35373217, -0.018957434, 0.008907322,
                       -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
                       0.040952638, 0.3147856, 0.08225149, -0.057416286,
                       -0.14995944, -0.008040261, 0.13208859, 0.029760877};

  recurrent_to_input_weights_ = {
      -0.001374326, -0.078856036, 0.10672688, 0.029162422,
      -0.11585556, 0.02557986, -0.13446963, -0.035785314,
      -0.01244275, 0.025961924, -0.02337298, -0.044228926,
      -0.055839065, -0.046598054, -0.010546039, -0.06900766,
      0.027239809, 0.022582639, -0.013296484, -0.05459212,
      0.08981, -0.045407712, 0.08682226, -0.06867011,
      -0.14390695, -0.02916037, 0.000996957, 0.091420636,
      0.14283475, -0.07390571, -0.06402044, 0.062524505,
      -0.093129106, 0.04860203, -0.08364217, -0.08119002,
      0.009352075, 0.22920375, 0.0016303885, 0.11583097,
      -0.13732095, 0.012405723, -0.07551853, 0.06343048,
      0.12162708, -0.031923793, -0.014335606, 0.01790974,
      -0.10650317, -0.0724401, 0.08554849, -0.05727212,
      0.06556731, -0.042729504, -0.043227166, 0.011683251,
      -0.013082158, -0.029302018, -0.010899579, -0.062036745,
      -0.022509435, -0.00964907, -0.01567329, 0.04260106,
      -0.07787477, -0.11576462, 0.017356863, 0.048673786,
      -0.017577527, -0.05527947, -0.082487635, -0.040137455,
      -0.10820036, -0.04666372, 0.022746278, -0.07851417,
      0.01068115, 0.032956902, 0.022433773, 0.0026891115,
      0.08944216, -0.0685835, 0.010513544, 0.07228705,
      0.02032331, -0.059686817, -0.0005566496, -0.086984694,
      0.040414046, -0.1380399, 0.094208956, -0.05722982,
      0.012092817, -0.04989123, -0.086576, -0.003399834,
      -0.04696032, -0.045747425, 0.10091314, 0.048676282,
      -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
      0.09504992, 0.041799378, -0.049185462, -0.031518843,
      -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
      -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
      -0.10167381, 0.042500053, -0.01447153, 0.06464186,
      -0.017142897, 0.03312627, 0.009205989, 0.024138335,
      -0.011337001, 0.035530265, -0.010912711, 0.0706555,
      -0.005894094, 0.051841937, -0.1401738, -0.02351249,
      0.0365468, 0.07590991, 0.08838724, 0.021681072,
      -0.10086113, 0.019608743, -0.06195883, 0.077335775,
      0.023646897, -0.095322326, 0.02233014, 0.09756986,
      -0.048691444, -0.009579111, 0.07595467, 0.11480546,
      -0.09801813, 0.019894179, 0.08502348, 0.004032281,
      0.037211012, 0.068537936, -0.048005626, -0.091520436,
      -0.028379958, -0.01556313, 0.06554592, -0.045599163,
      -0.01672207, -0.020169014, -0.011877351, -0.20212261,
      0.010889619, 0.0047078193, 0.038385306, 0.08540671,
      -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
      0.015963363, 0.00871737, 0.060130805, 0.028611384,
      0.10109069, -0.015060172, -0.07894427, 0.06401885,
      0.011584063, -0.024466386, 0.0047652307, -0.09041358,
      0.030737216, -0.0046374933, 0.14215417, -0.11823516,
      0.019899689, 0.006106124, -0.027092824, 0.0786356,
      0.05052217, -0.058925, -0.011402121, -0.024987547,
      -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
      -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
      -0.033664223, -0.07978348, -0.025200296, -0.017207067,
      -0.058403496, -0.055697463, 0.005798788, 0.12965427,
      -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
      0.072521195, -0.0029455067, -0.13797039, -0.03628521,
      0.013806405, -0.017858358, -0.01008298, -0.07700066,
      -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
      0.062634714, -0.02338735, -0.039547626, -0.02050681,
      0.03385117, -0.083611414, 0.002862572, -0.09421313,
      0.058618143, -0.08598433, 0.00972939, 0.023867095,
      -0.053934585, -0.023203006, 0.07452513, -0.048767887,
      -0.07314807, -0.056307215, -0.10433547, -0.06440842,
      0.04328182, 0.04389765, -0.020006588, -0.09076438,
      -0.11652589, -0.021705797, 0.03345259, -0.010329105,
      -0.025767034, 0.013057034, -0.07316461, -0.10145612,
      0.06358255, 0.18531723, 0.07759293, 0.12006465,
      0.1305557, 0.058638252, -0.03393652, 0.09622831,
      -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
      -0.005644518, 0.06857898, -0.12598175, -0.035084512,
      0.03156317, -0.12794146, -0.031963028, 0.04692781,
      0.030070418, 0.0071660685, -0.095516115, -0.004643372,
      0.040170413, -0.062104587, -0.0037324072, 0.0554317,
      0.08184801, -0.019164372, 0.06791302, 0.034257166,
      -0.10307039, 0.021943003, 0.046745934, 0.0790918,
      -0.0265588, -0.007824208, 0.042546265, -0.00977924,
      -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
      -0.014512694, -0.08251313, 0.08861942, 0.13589665,
      0.026351685, 0.012641483, 0.07466548, 0.044301085,
      -0.045414884, -0.051112458, 0.03444247, -0.08502782,
      -0.04106223, -0.028126027, 0.028473156, 0.10467447};

  recurrent_to_cell_weights_ = {
      -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
      0.055647098, -0.05713207, -0.05626563, 0.005559383,
      0.03375411, -0.025757805, -0.088049285, 0.06017052,
      -0.06570978, 0.007384076, 0.035123326, -0.07920549,
      0.053676967, 0.044480428, -0.07663568, 0.0071805613,
      0.08089997, 0.05143358, 0.038261272, 0.03339287,
      -0.027673481, 0.044746667, 0.028349208, 0.020090483,
      -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
      -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
      -0.10893326, 0.076739706, -0.08509834, -0.027997585,
      0.037871376, 0.01449768, -0.09002357, -0.06111149,
      -0.046195522, 0.0422062, -0.005683705, -0.1253618,
      -0.012925729, -0.04890792, 0.06985068, 0.037654128,
      0.03398274, -0.004781977, 0.007032333, -0.031787455,
      0.010868644, -0.031489216, 0.09525667, 0.013939797,
      0.0058680447, 0.0167067, 0.02668468, -0.04797466,
      -0.048885044, -0.12722108, 0.035304096, 0.06554885,
      0.00972396, -0.039238118, -0.05159735, -0.11329045,
      0.1613692, -0.03750952, 0.06529313, -0.071974665,
      -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
      0.02786344, -0.014179351, 0.005264273, 0.14376344,
      0.015983658, 0.03406988, -0.06939408, 0.040699873,
      0.02111075, 0.09669095, 0.041345075, -0.08316494,
      -0.07684199, -0.045768797, 0.032298047, -0.041805092,
      0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
      -0.024950314, 0.11574242, 0.04508852, -0.04335324,
      0.06760663, -0.027437469, 0.07216407, 0.06977076,
      -0.05438599, 0.034033038, -0.028602652, 0.05346137,
      0.043184172, -0.037189785, 0.10420091, 0.00882477,
      -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
      0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
      0.04361412, -0.007001822, 0.09631092, -0.06702025,
      -0.042049985, -0.035070654, -0.04103342, -0.10273396,
      0.0544271, 0.037184782, -0.13150354, -0.0058036847,
      -0.008264958, 0.042035464, 0.05891794, 0.029673764,
      0.0063542654, 0.044788733, 0.054816857, 0.062257513,
      -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
      -0.04043371, -0.017094059, 0.07229206, -0.023670016,
      -0.052195564, -0.025616996, -0.01520939, 0.045104615,
      -0.007376126, 0.003533447, 0.006570588, 0.056037236,
      0.12436656, 0.051817212, 0.028532185, -0.08686856,
      0.11868599, 0.07663395, -0.07323171, 0.03463402,
      -0.050708205, -0.04458982, -0.11590894, 0.021273347,
      0.1251325, -0.15313013, -0.12224372, 0.17228661,
      0.023029093, 0.086124025, 0.006445803, -0.03496501,
      0.028332196, 0.04449512, -0.042436164, -0.026587414,
      -0.006041347, -0.09292539, -0.05678812, 0.03897832,
      0.09465633, 0.008115513, -0.02171956, 0.08304309,
      0.071401566, 0.019622514, 0.032163795, -0.004167056,
      0.02295182, 0.030739572, 0.056506045, 0.004612461,
      0.06524936, 0.059999723, 0.046395954, -0.0045512207,
      -0.1335546, -0.030136576, 0.11584653, -0.014678886,
      0.0020118146, -0.09688814, -0.0790206, 0.039770417,
      -0.0329582, 0.07922767, 0.029322514, 0.026405897,
      0.04207835, -0.07073373, 0.063781224, 0.0859677,
      -0.10925287, -0.07011058, 0.048005477, 0.03438226,
      -0.09606514, -0.006669445, -0.043381985, 0.04240257,
      -0.06955775, -0.06769346, 0.043903265, -0.026784198,
      -0.017840602, 0.024307009, -0.040079936, -0.019946516,
      0.045318738, -0.12233574, 0.026170589, 0.0074471775,
      0.15978073, 0.10185836, 0.10298046, -0.015476589,
      -0.039390966, -0.072174534, 0.0739445, -0.1211869,
      -0.0347889, -0.07943156, 0.014809798, -0.12412325,
      -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
      -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
      -0.01514876, -0.056505352, -0.012800942, -0.06994386,
      0.012962922, -0.031234352, 0.07029052, 0.016418684,
      0.03618972, 0.055686004, -0.08663945, -0.017404709,
      -0.054761406, 0.029065743, 0.052404847, 0.020238016,
      0.0048197987, -0.0214882, 0.07078733, 0.013016777,
      0.06262858, 0.009184685, 0.020785125, -0.043904778,
      -0.0270329, -0.03299152, -0.060088247, -0.015162964,
      -0.001828936, 0.12642565, -0.056757294, 0.013586685,
      0.09232601, -0.035886683, 0.06000002, 0.05229691,
      -0.052580316, -0.082029596, -0.010794592, 0.012947712,
      -0.036429964, -0.085508935, -0.13127148, -0.017744139,
      0.031502828, 0.036232427, -0.031581745, 0.023051167,
      -0.05325106, -0.03421577, 0.028793324, -0.034633752,
      -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
      -0.008799762, 0.056595087, 0.0022273948, 0.055752404};

  recurrent_to_forget_weights_ = {
      -0.057784554, -0.026057621, -0.068447545, -0.022581743,
      0.14811787, 0.10826372, 0.09471067, 0.03987225,
      -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
      0.08414449, -0.022036452, -0.00066928595, -0.09203576,
      0.032950465, -0.10985798, -0.023809856, 0.0021431844,
      -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
      -0.06193199, 0.055729095, 0.03736828, 0.020123724,
      0.061878487, -0.04729229, 0.034919553, -0.07585433,
      -0.04421272, -0.044019096, 0.085488975, 0.04058006,
      -0.06890133, -0.030951202, -0.024628663, -0.07672815,
      0.034293607, 0.08556707, -0.05293577, -0.033561368,
      -0.04899627, 0.0241671, 0.015736353, -0.095442444,
      -0.029564252, 0.016493602, -0.035026584, 0.022337519,
      -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
      0.016435321, -0.03263031, -0.09543275, -0.047392778,
      0.013454138, 0.028934088, 0.01685226, -0.086110644,
      -0.046250615, -0.01847454, 0.047608484, 0.07339695,
      0.034546845, -0.04881143, 0.009128804, -0.08802852,
      0.03761666, 0.008096139, -0.014454086, 0.014361001,
      -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
      -0.06509276, -0.006021153, -0.08570962, -0.1451793,
      0.060212336, 0.055259194, 0.06974018, 0.049454916,
      -0.027794661, -0.08077226, -0.016179763, 0.1169753,
      0.17213494, -0.0056326236, -0.053934924, -0.0124349,
      -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
      0.0042065294, 0.03881498, 0.019844765, 0.041858196,
      -0.05695512, 0.047233116, 0.038937137, -0.06542224,
      0.014429736, -0.09719407, 0.13908425, -0.05379757,
      0.012321099, 0.082840554, -0.029899208, 0.044217527,
      0.059855383, 0.07711018, -0.045319796, 0.0948846,
      -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
      -0.13873616, 0.040668588, 0.034832682, -0.015319203,
      -0.018715994, 0.046002675, 0.0599172, -0.043107376,
      0.0294216, -0.002314414, -0.022424703, 0.0030315618,
      0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
      0.12375372, -0.0006038222, 0.029104086, 0.087442465,
      0.052958444, 0.07558703, 0.04817258, 0.044462286,
      -0.015213451, -0.08783778, -0.0561384, -0.003008196,
      0.047060397, -0.002058388, 0.03429439, -0.018839769,
      0.024734668, 0.024614193, -0.042046934, 0.09597743,
      -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
      -0.02558259, -0.022822596, -0.023273505, -0.02464396,
      -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
      0.04383914, -0.046476185, 0.028658995, 0.060410924,
      0.050786525, 0.009452605, -0.0073054377, -0.024810238,
      0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
      0.015898481, 0.021362653, -0.030262267, 0.016587038,
      -0.011442813, 0.041154444, -0.007631438, -0.03423484,
      -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
      0.02318443, -0.041350313, 0.021485701, -0.10906167,
      -0.028218046, -0.00954771, 0.020531068, -0.11995105,
      -0.03672871, 0.024019798, 0.014255957, -0.05221243,
      -0.00661567, -0.04630967, 0.033188973, 0.10107534,
      -0.014027541, 0.030796422, -0.10270911, -0.035999842,
      0.15443139, 0.07684145, 0.036571592, -0.035900835,
      -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
      -0.03858649, 0.01849943, 0.13872518, 0.01503974,
      0.069941424, -0.06948533, -0.0088794185, 0.061282158,
      -0.047401894, 0.03100163, -0.041533746, -0.10430945,
      0.044574402, -0.01425562, -0.024290353, 0.034563623,
      0.05866852, 0.023947537, -0.09445152, 0.035450947,
      0.02247216, -0.0042998926, 0.061146557, -0.10250651,
      0.020881841, -0.06747029, 0.10062043, -0.0023941975,
      0.03532124, -0.016341697, 0.09685456, -0.016764693,
      0.051808182, 0.05875331, -0.04536488, 0.001626336,
      -0.028892258, -0.01048663, -0.009793449, -0.017093895,
      0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
      -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
      -0.01769146, 0.040995963, 0.02235177, -0.060430344,
      0.11475477, -0.023854522, 0.10071741, 0.0686208,
      -0.014250481, 0.034261297, 0.047418304, 0.08562733,
      -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
      0.04096551, 0.032249358, -0.08355519, -0.026823482,
      0.056386515, -0.010401743, -0.028396193, 0.08507674,
      0.014410365, 0.020995233, 0.17040324, 0.11511526,
      0.02459721, 0.0066619175, 0.025853224, -0.023133837,
      -0.081302024, 0.017264642, -0.009585969, 0.09491168,
      -0.051313367, 0.054532815, -0.014298593, 0.10657464,
      0.007076659, 0.10964551, 0.0409152, 0.008275321,
      -0.07283536, 0.07937492, 0.04192024, -0.1075027};

  recurrent_to_output_weights_ = {
      0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
      -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
      -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
      -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
      -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
      -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
      -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
      0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
      -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
      0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
      -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
      -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
      0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
      0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
      -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
      0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
      0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
      0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
      0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
      0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
      -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
      0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
      -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
      0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
      0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
      0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
      -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
      -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
      -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
      -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
      -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
      -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
      0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
      0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
      -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
      0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
      -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
      -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
      -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
      0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
      0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
      0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
      -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
      0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
      -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
      -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
      -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
      -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
      0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
      -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
      0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
      -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
      -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
      -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
      -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
      0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
      0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
      -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
      0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
      0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
      -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
      0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
      0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
      0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
  };

  cell_to_input_weights_ = {
      0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
      -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
      -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
      0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};

  cell_to_forget_weights_ = {
      -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
      -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
      -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
      0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};

  cell_to_output_weights_ = {0.08286371, -0.08261836, -0.51210177, 0.002913762,
                             0.17764764, -0.5495371, -0.08460716, -0.24552552,
                             0.030037103, 0.04123544, -0.11940523, 0.007358328,
                             0.1890978, 0.4833202, -0.34441817, 0.36312827,
                             -0.26375428, 0.1457655, -0.19724406, 0.15548733};

  projection_weights_ = {
      -0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
      0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
      -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
      -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
      0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
      0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
      0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
      0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
      -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
      -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
      -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
      0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
      0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
      0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
      0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
      0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
      -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
      0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
      -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
      0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
      -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
      -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
      0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
      -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
      0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
      -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
      -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
      0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
      -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
      -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
      -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
      0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
      0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
      -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
      0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
      0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
      0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
      0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
      0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
      -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
      -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
      0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
      -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
      -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
      0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
      0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
      0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
      -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
      -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
      -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
      0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
      -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
      0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
      0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
      -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
      -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
      -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
      0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
      -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
      -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
      -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
      0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
      0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
      0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656};

  lstm_input_ = {// Step 1
                 {{0.787926, 0.151646, 0.071352, 0.118426, 0.458058},
                  {0.295743, 0.544053, 0.690064, 0.858138, 0.497181}},
                 // Step 2
                 {{0.596268, 0.998386, 0.568695, 0.864524, 0.571277},
                  {0.642421, 0.524260, 0.134799, 0.003639, 0.162482}},
                 // Step 3
                 {{0.073204, 0.296072, 0.743333, 0.069199, 0.045348},
                  {0.640394, 0.930399, 0.050782, 0.432485, 0.988078}},
                 // Step 4
                 {{0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
                  {0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}};

  lstm_golden_output_ = {
      {{-0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, -0.0211779,
        0.0283512, -0.0114597, 0.00907307, -0.0244004, -0.0152191, -0.0259063,
        0.00914318, 0.00415118, 0.017147, 0.0134203},
       {-0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, -0.0186926,
        0.0193662, -0.0115437, 0.00422612, -0.0345232, 0.00223253, -0.00957321,
        0.0210624, 0.013331, 0.0150954, 0.02168}},

      {{-0.0166936, 0.0381209, 0.000889694, 0.0143363, -0.0328911, -0.0234288,
        0.0333051, -0.012229, 0.0110322, -0.0457725, -0.000832209, -0.0202817,
        0.0327257, 0.0121308, 0.0155969, 0.0312091},
       {-0.0141913, 0.0322082, 0.00227024, 0.0260507, -0.0188721, -0.0296489,
        0.0399134, -0.0160509, 0.0116039, -0.0447318, -0.0150515, -0.0277406,
        0.0316596, 0.0118233, 0.0214762, 0.0293641}},

      {{-0.0213783, 0.0350169, 0.000324794, 0.0276012, -0.0263374, -0.0371449,
        0.0446149, -0.0205474, 0.0103729, -0.0576349, -0.0150052, -0.0292043,
        0.0376827, 0.0136115, 0.0243435, 0.0354492},
       {-0.0204549, 0.0450315, -0.00117378, 0.0167673, -0.0375007, -0.0238314,
        0.038784, -0.0174034, 0.0131743, -0.0506589, -0.0048447, -0.0240239,
        0.0325789, 0.00790065, 0.0220157, 0.0333314}},

      {{-0.0189322, 0.0464512, -0.00251373, 0.0225745, -0.0308346, -0.0317124,
        0.0460407, -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
        0.0286833, 0.00824207, 0.0264887, 0.0305169},
       {-0.0264787, 0.0387855, -0.000764675, 0.0217599, -0.037537, -0.0335206,
        0.0431679, -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
        0.0412031, 0.0118723, 0.0239643, 0.0394009}}};

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/true,
                   /*use_projection_weights=*/true,
                   /*use_projection_bias=*/false, weight_type,
                   model_has_legacy_20_inputs, /*is_layer_norm=*/false,
                   asymmetric_quantize_inputs);

  static const auto* tolerance_per_type = new std::map<TensorType, float>{
      {TensorType_FLOAT32, 0.00001f},
      {TensorType_UINT8, 0.00467f},
      {TensorType_INT8, 0.0015f},
  };
  VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
}

TEST_P(LstmOpTest, NoCifg_Peephole_Projection_LayerNorm) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;

  TensorType weight_type;
  // Layer normalization needs the 24-input model, so
  // model_has_legacy_20_inputs is ignored here.
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, std::ignore, asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
                             0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
                             -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};

  input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
                              -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
                              -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};

  input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
                            -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
                            -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};

  input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                              -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
                              -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};

  input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};

  forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};

  cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};

  output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

  recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
                                 -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};

  recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
                                -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};

  recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                  0.9, 0.3, -0.1, 0.2, 0.5, 0.2};

  recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
                                  -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

  cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};

  cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};

  cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

  input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
  forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
  cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
  output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

  projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
                         0.3, 0.08, 0.07, 0.2, -0.4, 0.2};

  lstm_input_ = {
      {{0.7, 0.8, 0.1, 0.2, 0.3}, {0.3, 0.2, 0.9, 0.8, 0.1}},

      {{0.8, 0.1, 0.2, 0.4, 0.5}, {0.1, 0.5, 0.2, 0.4, 0.2}},

      {{0.2, 0.7, 0.7, 0.1, 0.7}, {0.6, 0.9, 0.2, 0.5, 0.7}},
  };

  lstm_golden_output_ = {
      {{0.0244077, 0.128027, -0.00170918}, {-0.00692428, 0.0848741, 0.063445}},

      {{0.0137642, 0.140751, 0.0395835}, {-0.00403912, 0.139963, 0.072681}},

      {{-0.00459231, 0.155278, 0.0837377}, {0.00752706, 0.161903, 0.0561371}}};

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/true,
                   /*use_projection_weights=*/true,
                   /*use_projection_bias=*/false, weight_type,
                   /*model_has_legacy_20_inputs=*/false,
                   /*is_layer_norm=*/true, asymmetric_quantize_inputs);

  static const auto* tolerance_per_type =
      new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
                                      {TensorType_UINT8, 0.0010907f},
                                      {TensorType_INT8, 0.00106f}};
  VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
}
1237
1238 TEST_P(LstmOpTest, Cifg_Peephole_Projection_LayerNorm) {
1239 const int n_batch = 2;
1240 const int n_input = 5;
1241 const int n_cell = 4;
1242 const int n_output = 3;
1243
1244 TensorType weight_type;
1245 // Layer normalization needs 24 inputs, so the legacy 20-input layout cannot
1245 // apply; the model_has_legacy_20_inputs parameter is ignored below.
1246 bool asymmetric_quantize_inputs;
1247 std::tie(weight_type, std::ignore, asymmetric_quantize_inputs) = GetParam();
1248
1249 // TODO(b/158205028): Fix this test if using NN-API.
1250 if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
1251 return;
1252 }
1253
1254 input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
1255 -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
1256 -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
1257 input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
1258 -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
1259 -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
1260 input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
1261 -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
1262 -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
1263
1264 forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
1265 cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
1266 output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
1267
1268 recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
1269 -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1270 recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
1271 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1272 recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
1273 -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1274
1275 cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
1276 cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
1277
1278 forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
1279 cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
1280 output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
1281 projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
1282 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1283
1284 lstm_input_ = {{{0.7, 0.8, 0.1, 0.2, 0.3}, {0.3, 0.2, 0.9, 0.8, 0.1}},
1285
1286 {{0.8, 0.1, 0.2, 0.4, 0.5}, {0.1, 0.5, 0.2, 0.4, 0.2}},
1287
1288 {{0.2, 0.7, 0.7, 0.1, 0.7}, {0.6, 0.9, 0.2, 0.5, 0.7}}};
1289 lstm_golden_output_ = {{{0.02129706, 0.140816242, 0.0112733059},
1290 {-0.0226350538, 0.0916948169, 0.0769175813}},
1291
1292 {{0.0132302344, 0.152308047, 0.0346313119},
1293 {-0.0269966982, 0.149707705, 0.094149217}},
1294
1295 {{-0.0123688057, 0.165790111, 0.0893077999},
1296 {-0.0103429332, 0.173016444, 0.0720508844}}};
1297
1298 LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
1299 /*use_cifg=*/true, /*use_peephole=*/true,
1300 /*use_projection_weights=*/true,
1301 /*use_projection_bias=*/false, weight_type,
1302 /*model_has_legacy_20_inputs=*/false,
1303 /*is_layer_norm=*/true, asymmetric_quantize_inputs);
1304
1305 static const auto* tolerance_per_type =
1306 new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
1307 {TensorType_UINT8, 0.000971057f},
1308 {TensorType_INT8, 0.001f}};
1309 VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
1310 }
1311
1312 class LSTMIntegerOpModel : public SingleOpModel {
1313 public:
1314 LSTMIntegerOpModel(int n_batch, int n_input, int n_cell, int n_output,
1315 bool use_cifg, bool use_peephole,
1316 bool use_projection_weights, bool use_projection_bias,
1317 bool use_layer_norm, bool use_8x8_8_implementation,
1318 const std::vector<std::pair<float, float>>& ranges,
1319 const std::vector<std::pair<float, int>>& intermediates)
1320 : n_input_(n_input), n_output_(n_output) {
1321 input_ = AddInput({TensorType_INT8,
1322 {n_batch, n_input},
1323 ranges[0].first,
1324 ranges[0].second});
1325
1326 if (use_cifg) {
1327 input_to_input_weights_ = AddNullInput();
1328 } else {
1329 input_to_input_weights_ = AddInput({TensorType_INT8,
1330 {n_cell, n_input},
1331 ranges[1].first,
1332 ranges[1].second});
1333 }
1334 input_to_forget_weights_ = AddInput({TensorType_INT8,
1335 {n_cell, n_input},
1336 ranges[2].first,
1337 ranges[2].second});
1338 input_to_cell_weights_ = AddInput({TensorType_INT8,
1339 {n_cell, n_input},
1340 ranges[3].first,
1341 ranges[3].second});
1342 input_to_output_weights_ = AddInput({TensorType_INT8,
1343 {n_cell, n_input},
1344 ranges[4].first,
1345 ranges[4].second});
1346
1347 if (use_cifg) {
1348 recurrent_to_input_weights_ = AddNullInput();
1349 } else {
1350 recurrent_to_input_weights_ = AddInput({TensorType_INT8,
1351 {n_cell, n_output},
1352 ranges[5].first,
1353 ranges[5].second});
1354 }
1355 recurrent_to_forget_weights_ = AddInput({TensorType_INT8,
1356 {n_cell, n_output},
1357 ranges[6].first,
1358 ranges[6].second});
1359 recurrent_to_cell_weights_ = AddInput({TensorType_INT8,
1360 {n_cell, n_output},
1361 ranges[7].first,
1362 ranges[7].second});
1363 recurrent_to_output_weights_ = AddInput({TensorType_INT8,
1364 {n_cell, n_output},
1365 ranges[8].first,
1366 ranges[8].second});
1367
1368 if (use_peephole) {
1369 if (use_cifg) {
1370 cell_to_input_weights_ = AddNullInput();
1371 } else {
1372 cell_to_input_weights_ = AddInput(
1373 {TensorType_INT16, {n_cell}, ranges[9].first, ranges[9].second});
1374 }
1375 cell_to_forget_weights_ = AddInput(
1376 {TensorType_INT16, {n_cell}, ranges[10].first, ranges[10].second});
1377 cell_to_output_weights_ = AddInput(
1378 {TensorType_INT16, {n_cell}, ranges[11].first, ranges[11].second});
1379 } else {
1380 cell_to_input_weights_ = AddNullInput();
1381 cell_to_forget_weights_ = AddNullInput();
1382 cell_to_output_weights_ = AddNullInput();
1383 }
1384
1385 if (use_cifg) {
1386 input_gate_bias_ = AddNullInput();
1387 } else {
1388 input_gate_bias_ = AddInput(
1389 {TensorType_INT32, {n_cell}, ranges[12].first, ranges[12].second});
1390 }
1391 forget_gate_bias_ = AddInput(
1392 {TensorType_INT32, {n_cell}, ranges[13].first, ranges[13].second});
1393 cell_gate_bias_ = AddInput(
1394 {TensorType_INT32, {n_cell}, ranges[14].first, ranges[14].second});
1395 output_gate_bias_ = AddInput(
1396 {TensorType_INT32, {n_cell}, ranges[15].first, ranges[15].second});
1397
1398 if (use_projection_weights) {
1399 projection_weights_ = AddInput({TensorType_INT8,
1400 {n_output, n_cell},
1401 ranges[16].first,
1402 ranges[16].second});
1403 } else {
1404 projection_weights_ = AddNullInput();
1405 }
1406 if (use_projection_bias) {
1407 CHECK(use_projection_weights);
1408 projection_bias_ = AddInput(
1409 {TensorType_INT32, {n_output}, ranges[17].first, ranges[17].second});
1410 } else {
1411 projection_bias_ = AddNullInput();
1412 }
1413
1414 // Adding the 2 state tensors.
1415 AddVariableInput({TensorType_INT16,
1416 {n_batch, n_output},
1417 ranges[18].first,
1418 ranges[18].second});
1419 AddVariableInput({TensorType_INT16,
1420 {n_batch, n_cell},
1421 ranges[19].first,
1422 ranges[19].second});
1423
1424 // Layer norm weights.
1425 if (use_layer_norm) {
1426 if (use_cifg) {
1427 input_layer_norm_coefficients_ = AddNullInput();
1428 } else {
1429 input_layer_norm_coefficients_ = AddInput(
1430 {TensorType_INT16, {n_cell}, ranges[20].first, ranges[20].second});
1431 }
1432 forget_layer_norm_coefficients_ = AddInput(
1433 {TensorType_INT16, {n_cell}, ranges[21].first, ranges[21].second});
1434 cell_layer_norm_coefficients_ = AddInput(
1435 {TensorType_INT16, {n_cell}, ranges[22].first, ranges[22].second});
1436 output_layer_norm_coefficients_ = AddInput(
1437 {TensorType_INT16, {n_cell}, ranges[23].first, ranges[23].second});
1438 }
1439
1440 if (use_8x8_8_implementation) {
1441 EXPECT_EQ(intermediates.size(), 12);
1442 } else {
1443 EXPECT_EQ(intermediates.size(), 5);
1444 }
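// A sketch of what these counts presumably cover: the 8x8_16 kernel tracks
// quantization parameters for the four gate pre-activations plus the
// effective hidden-state scale (5 intermediates), while the 8x8_8 kernel
// quantizes additional stages explicitly (12). The authoritative breakdown
// lives in the quantized LSTM kernel implementation.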
1445 for (int i = 0; i < intermediates.size(); ++i) {
1446 AddIntermediate(TensorType_INT16, {intermediates[i].first},
1447 {intermediates[i].second});
1448 }
1449
1450 output_ = AddOutput({TensorType_INT8,
1451 {n_batch, n_output},
1452 ranges[24].first,
1453 ranges[24].second});
1454
1455 // TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the
1456 // default 0.
1457 SetBuiltinOp(
1458 BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
1459 CreateLSTMOptions(builder_, ActivationFunctionType_TANH).Union());
1460
1461 BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1,
1462 /*allow_fp32_relax_to_fp16=*/false,
1463 /*apply_delegate=*/true, /*allocate_and_delegate=*/false);
1464 }
1465
1466 void PerformAllocateAndDelegate() { AllocateAndDelegate(true); }
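// The constructor above builds the interpreter with
// allocate_and_delegate=false, so each test calls this explicitly and thus
// controls when tensors are allocated and the delegate is applied relative
// to populating the quantized weights.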
1467
1468 void SetInputToInputWeights(const std::vector<float>& f) {
1469 QuantizeAndPopulate<int8_t>(input_to_input_weights_, f);
1470 }
1471
1472 void SetInputToForgetWeights(const std::vector<float>& f) {
1473 QuantizeAndPopulate<int8_t>(input_to_forget_weights_, f);
1474 }
1475
1476 void SetInputToCellWeights(const std::vector<float>& f) {
1477 QuantizeAndPopulate<int8_t>(input_to_cell_weights_, f);
1478 }
1479
1480 void SetInputToOutputWeights(const std::vector<float>& f) {
1481 QuantizeAndPopulate<int8_t>(input_to_output_weights_, f);
1482 }
1483
1484 void SetRecurrentToInputWeights(const std::vector<float>& f) {
1485 QuantizeAndPopulate<int8_t>(recurrent_to_input_weights_, f);
1486 }
1487
1488 void SetRecurrentToForgetWeights(const std::vector<float>& f) {
1489 QuantizeAndPopulate<int8_t>(recurrent_to_forget_weights_, f);
1490 }
1491
1492 void SetRecurrentToCellWeights(const std::vector<float>& f) {
1493 QuantizeAndPopulate<int8_t>(recurrent_to_cell_weights_, f);
1494 }
1495
1496 void SetRecurrentToOutputWeights(const std::vector<float>& f) {
1497 QuantizeAndPopulate<int8_t>(recurrent_to_output_weights_, f);
1498 }
1499
1500 void SetCellToInputWeights(const std::vector<float>& f) {
1501 QuantizeAndPopulate<int16_t>(cell_to_input_weights_, f);
1502 }
1503
1504 void SetCellToForgetWeights(const std::vector<float>& f) {
1505 QuantizeAndPopulate<int16_t>(cell_to_forget_weights_, f);
1506 }
1507
1508 void SetCellToOutputWeights(const std::vector<float>& f) {
1509 QuantizeAndPopulate<int16_t>(cell_to_output_weights_, f);
1510 }
1511
1512 void SetInputLayerNormCoefficients(const std::vector<float>& f) {
1513 QuantizeAndPopulate<int16_t>(input_layer_norm_coefficients_, f);
1514 }
1515
1516 void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
1517 QuantizeAndPopulate<int16_t>(forget_layer_norm_coefficients_, f);
1518 }
1519
1520 void SetCellLayerNormCoefficients(const std::vector<float>& f) {
1521 QuantizeAndPopulate<int16_t>(cell_layer_norm_coefficients_, f);
1522 }
1523
1524 void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
1525 QuantizeAndPopulate<int16_t>(output_layer_norm_coefficients_, f);
1526 }
1527
1528 void SetInputGateBias(const std::vector<float>& f) {
1529 QuantizeAndPopulate<int32_t>(input_gate_bias_, f);
1530 }
1531
1532 void SetForgetGateBias(const std::vector<float>& f) {
1533 QuantizeAndPopulate<int32_t>(forget_gate_bias_, f);
1534 }
1535
1536 void SetCellBias(const std::vector<float>& f) {
1537 QuantizeAndPopulate<int32_t>(cell_gate_bias_, f);
1538 }
1539
1540 void SetOutputGateBias(const std::vector<float>& f) {
1541 QuantizeAndPopulate<int32_t>(output_gate_bias_, f);
1542 }
1543
1544 void SetProjectionWeights(const std::vector<float>& f) {
1545 QuantizeAndPopulate<int8_t>(projection_weights_, f);
1546 }
1547
1548 void SetProjectionBias(const std::vector<float>& f) {
1549 QuantizeAndPopulate<int32_t>(projection_bias_, f);
1550 }
1551
1552 void SetInput(const std::vector<float>& f) {
1553 QuantizeAndPopulate<int8_t>(input_, f);
1554 }
1555
1556 std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
1557
1558 int num_inputs() { return n_input_; }
1559 int num_outputs() { return n_output_; }
1560
1561 protected:
1562 int input_;
1563 int input_to_input_weights_;
1564 int input_to_forget_weights_;
1565 int input_to_cell_weights_;
1566 int input_to_output_weights_;
1567
1568 int recurrent_to_input_weights_;
1569 int recurrent_to_forget_weights_;
1570 int recurrent_to_cell_weights_;
1571 int recurrent_to_output_weights_;
1572
1573 int cell_to_input_weights_;
1574 int cell_to_forget_weights_;
1575 int cell_to_output_weights_;
1576
1577 int input_layer_norm_coefficients_;
1578 int forget_layer_norm_coefficients_;
1579 int cell_layer_norm_coefficients_;
1580 int output_layer_norm_coefficients_;
1581
1582 int input_gate_bias_;
1583 int forget_gate_bias_;
1584 int cell_gate_bias_;
1585 int output_gate_bias_;
1586
1587 int projection_weights_;
1588 int projection_bias_;
1589
1590 int output_;
1591
1592 int n_input_;
1593 int n_output_;
1594 };
1595
1596 TEST(IntegerLstmOpTest, NoCifg_NoPeephole_Projection_LayerNorm) {
1597 // Hyperparameters.
1598 const int n_batch = 2;
1599 const int n_input = 5;
1600 const int n_cell = 4;
1601 const int n_output = 3;
1602
1603 // Model-related weights.
1604 const std::vector<float> input_to_input_weights = {
1605 0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
1606 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1607
1608 const std::vector<float> input_to_forget_weights = {
1609 -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
1610 -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
1611
1612 const std::vector<float> input_to_cell_weights = {
1613 -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
1614 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
1615
1616 const std::vector<float> input_to_output_weights = {
1617 -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
1618 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
1619
1620 const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
1621
1622 const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
1623
1624 const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
1625
1626 const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
1627
1628 const std::vector<float> recurrent_to_input_weights = {
1629 -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1630
1631 const std::vector<float> recurrent_to_cell_weights = {
1632 -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1633
1634 const std::vector<float> recurrent_to_forget_weights = {
1635 -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1636
1637 const std::vector<float> recurrent_to_output_weights = {
1638 0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1639
1640 const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
1641 const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
1642 0.3};
1643 const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
1644 const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
1645 0.5};
1646
1647 const std::vector<float> projection_weights = {
1648 -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1649
1650 // Input ranges.
1651 const std::vector<std::pair<float, float>> ranges = {
1652 {-1.0, 127.0 / 128}, // input tensor
1653 {-1.0, 1.0}, // input_to_input_weight tensor
1654 {-1.0, 1.0}, // input_to_forget_weight tensor
1655 {-1.0, 1.0}, // input_to_cell_weight tensor
1656 {-1.0, 1.0}, // input_to_output_weight tensor
1657
1658 {-1.0, 1.0}, // recurrent_to_input_weight tensor
1659 {-1.0, 1.0}, // recurrent_to_forget_weight tensor
1660 {-1.0, 1.0}, // recurrent_to_cell_weight tensor
1661 {-1.0, 1.0}, // recurrent_to_output_weight tensor
1662
1663 {-1, 1}, // cell_to_input_weight tensor
1664 {-1, 1}, // cell_to_forget_weight tensor
1665 {-1, 1}, // cell_to_output_weight tensor
1666
1667 {-100, 100}, // input_gate_bias tensor
1668 {-100, 100}, // forget_gate_bias tensor
1669 {-100, 100}, // cell_gate_bias tensor
1670 {-100, 100}, // output_gate_bias tensor
1671
1672 {-0.5, 0.5}, // projection_weight tensor
1673 {-1, 1}, // projection_bias tensor
1674
1675 {-1.0, 32767.0 / 32768}, // output_state tensor
1676 {-1, 1}, // cell_state tensor
1677
1678 {-1.00001, 1.0}, // input_layer_norm_coefficient tensor
1679 {-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
1680 {-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
1681 {-1.00001, 1.0}, // output_layer_norm_coefficient tensor
1682 // Output scale is the same as output_state scale and only output_state
1683 // scale is used in the op, so this is only provided for clarity.
1684 {-1.0, 32767.0 / 32768}, // output tensor.
1685 };
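  // A minimal sanity sketch (not part of the original test): assuming the
  // standard TFLite asymmetric-quantization derivation, an int8 (min, max)
  // pair yields scale = (max - min) / 255. For {-1.0, 127.0 / 128} this is
  // exactly 1/128, with a zero point of 0.
  {
    const float range_min = -1.0f;
    const float range_max = 127.0f / 128;
    const float scale = (range_max - range_min) / 255.0f;
    const int zero_point =
        static_cast<int>(-128.0f - range_min / scale + 0.5f);
    EXPECT_FLOAT_EQ(scale, 1.0f / 128);
    EXPECT_EQ(zero_point, 0);
  }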
1686
1687 // The scale and zero point of intermediate tensors.
1688 std::vector<std::pair<float, int>> intermediates = {
1689 {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
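  // A minimal sketch (not part of the original test): under the usual affine
  // mapping, a real value x lands in an int16 intermediate tensor roughly as
  // round(x / scale) + zero_point. At the first scale above:
  {
    const float x = 0.5f;
    const float scale = 0.007059f;
    const int q = static_cast<int>(x / scale + 0.5f);  // zero_point is 0.
    EXPECT_EQ(q, 71);  // 0.5 / 0.007059 ~= 70.83, rounded to 71.
  }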
1690
1691 // Create model.
1692 LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
1693 /*use_cifg=*/false, /*use_peephole=*/false,
1694 /*use_projection_weights=*/true,
1695 /*use_projection_bias=*/false,
1696 /*use_layer_norm=*/true,
1697 /*use_8x8_8_implementation=*/false, ranges,
1698 intermediates);
1699 // Allocate tensors and apply the delegate.
1700 lstm.PerformAllocateAndDelegate();
1701
1702 // Set weights.
1703 lstm.SetInputToInputWeights(input_to_input_weights);
1704 lstm.SetInputToCellWeights(input_to_cell_weights);
1705 lstm.SetInputToForgetWeights(input_to_forget_weights);
1706 lstm.SetInputToOutputWeights(input_to_output_weights);
1707
1708 lstm.SetInputGateBias(input_gate_bias);
1709 lstm.SetCellBias(cell_gate_bias);
1710 lstm.SetForgetGateBias(forget_gate_bias);
1711 lstm.SetOutputGateBias(output_gate_bias);
1712
1713 lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
1714 lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
1715 lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
1716 lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
1717
1718 lstm.SetProjectionWeights(projection_weights);
1719
1720 lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
1721 lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
1722 lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
1723 lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
1724
1725 // Model inputs, laid out as sequence x batch x input.
1726 const std::vector<std::vector<float>> lstm_input = {
1727 {
1728 0.7, 0.8, 0.1, 0.2, 0.3, //
1729 0.8, 0.1, 0.2, 0.4, 0.5, //
1730 },
1731 {
1732 0.2, 0.7, 0.7, 0.1, 0.7, //
1733 0.3, 0.2, 0.9, 0.8, 0.1, //
1734 },
1735 {
1736 0.7, 0.8, 0.1, 0.2, 0.3, //
1737 0.3, 0.2, 0.9, 0.8, 0.1, //
1738 },
1739 };
1740
1741 // Expected outputs.
1742 const std::vector<std::vector<int8_t>> expected_output = {
1743 {127, 127, -108, -67, 127, 127},
1744 {-128, 127, 127, -128, 127, 127},
1745 {127, 127, 127, -128, 127, 127},
1746 };
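  // Most entries sit at the int8 extremes (127 / -128). With the output range
  // {-1.0, 32767/32768} above, the scale is roughly 1/128, so 127 decodes to
  // a float output near +1.0, i.e. the projected outputs saturate for these
  // inputs.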
1747
1748 // Invoke and verify the result.
1749 const int input_sequence_size = lstm_input.size();
1750 EXPECT_GT(input_sequence_size, 0);
1751 for (int i = 0; i < input_sequence_size; ++i) {
1752 lstm.SetInput(lstm_input[i]);
1753 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1754 EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
1755 }
1756 }
1757
1758 TEST(IntegerLstmOpTest, NoCifg_Peephole_Projection_LayerNorm) {
1759 // Hyperparameters.
1760 const int n_batch = 2;
1761 const int n_input = 5;
1762 const int n_cell = 4;
1763 const int n_output = 3;
1764
1765 // Model-related weights.
1766 const std::vector<float> input_to_input_weights = {
1767 0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
1768 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1769
1770 const std::vector<float> input_to_forget_weights = {
1771 -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
1772 -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
1773
1774 const std::vector<float> input_to_cell_weights = {
1775 -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
1776 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
1777
1778 const std::vector<float> input_to_output_weights = {
1779 -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
1780 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
1781
1782 const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
1783
1784 const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
1785
1786 const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
1787
1788 const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
1789
1790 const std::vector<float> recurrent_to_input_weights = {
1791 -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1792
1793 const std::vector<float> recurrent_to_cell_weights = {
1794 -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1795
1796 const std::vector<float> recurrent_to_forget_weights = {
1797 -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1798
1799 const std::vector<float> recurrent_to_output_weights = {
1800 0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1801
1802 const std::vector<float> cell_to_input_weights = {0.3, -0.1, 0.1, -0.2};
1803
1804 const std::vector<float> cell_to_forget_weights = {0.2, -0.1, 0.1, -0.2};
1805
1806 const std::vector<float> cell_to_output_weights = {0.3, -0.1, 0.1, -0.3};
1807
1808 const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
1809 const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
1810 0.3};
1811 const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
1812 const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
1813 0.5};
1814
1815 const std::vector<float> projection_weights = {
1816 -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1817
1818 // Input ranges.
1819 const std::vector<std::pair<float, float>> ranges = {
1820 {-1.0, 127.0 / 128}, // input tensor
1821 {-1.0, 1.0}, // input_to_input_weight tensor
1822 {-1.0, 1.0}, // input_to_forget_weight tensor
1823 {-1.0, 1.0}, // input_to_cell_weight tensor
1824 {-1.0, 1.0}, // input_to_output_weight tensor
1825
1826 {-1.0, 1.0}, // recurrent_to_input_weight tensor
1827 {-0.9, 0.9}, // recurrent_to_forget_weight tensor
1828 {-1.0, 1.0}, // recurrent_to_cell_weight tensor
1829 {-1.0, 1.0}, // recurrent_to_output_weight tensor
1830
1831 {-0.3, 0.3}, // cell_to_input_weight tensor
1832 {-0.3, 0.3}, // cell_to_forget_weight tensor
1833 {-0.3, 0.3}, // cell_to_output_weight tensor
1834
1835 {-100, 100}, // input_gate_bias tensor
1836 {-100, 80}, // forget_gate_bias tensor
1837 {-100, 100}, // cell_gate_bias tensor
1838 {-100, 100}, // output_gate_bias tensor
1839
1840 {-0.5, 0.5}, // projection_weight tensor
1841 {-1, 1}, // projection_bias tensor
1842
1843 {-1.0, 32767.0 / 32768}, // output_state tensor
1844 {-1, 1}, // cell_state tensor
1845
1846 {-0.5, 0.5}, // input_layer_norm_coefficient tensor
1847 {-0.5, 0.5}, // forget_layer_norm_coefficient tensor
1848 {-1.0, 1.0}, // cell_layer_norm_coefficient tensor
1849 {-1.0, 1.0}, // output_layer_norm_coefficient tensor
1850 // Output scale is the same as output_state scale and only output_state
1851 // scale is used in the op, so this is only provided for clarity.
1852 {-1.0, 32767.0 / 32768}, // output tensor.
1853 };
1854
1855 // The scale and zero point of intermediate tensors.
1856 std::vector<std::pair<float, int>> intermediates = {
1857 {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
1858
1859 // Create model.
1860 LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
1861 /*use_cifg=*/false, /*use_peephole=*/true,
1862 /*use_projection_weights=*/true,
1863 /*use_projection_bias=*/false,
1864 /*use_layer_norm=*/true,
1865 /*use_8x8_8_implementation=*/false, ranges,
1866 intermediates);
1867
1868 // Allocate tensors and apply the delegate.
1869 lstm.PerformAllocateAndDelegate();
1870
1871 // Set weights.
1872 lstm.SetInputToInputWeights(input_to_input_weights);
1873 lstm.SetInputToCellWeights(input_to_cell_weights);
1874 lstm.SetInputToForgetWeights(input_to_forget_weights);
1875 lstm.SetInputToOutputWeights(input_to_output_weights);
1876
1877 lstm.SetInputGateBias(input_gate_bias);
1878 lstm.SetCellBias(cell_gate_bias);
1879 lstm.SetForgetGateBias(forget_gate_bias);
1880 lstm.SetOutputGateBias(output_gate_bias);
1881
1882 lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
1883 lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
1884 lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
1885 lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
1886
1887 lstm.SetCellToInputWeights(cell_to_input_weights);
1888 lstm.SetCellToForgetWeights(cell_to_forget_weights);
1889 lstm.SetCellToOutputWeights(cell_to_output_weights);
1890
1891 lstm.SetProjectionWeights(projection_weights);
1892
1893 lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
1894 lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
1895 lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
1896 lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
1897
1898 // Model inputs, laid out as sequence x batch x input.
1899 const std::vector<std::vector<float>> lstm_input = {
1900 {
1901 0.7, 0.8, 0.1, 0.2, 0.3, //
1902 0.8, 0.1, 0.2, 0.4, 0.5, //
1903 },
1904 {
1905 0.2, 0.7, 0.7, 0.1, 0.7, //
1906 0.3, 0.2, 0.9, 0.8, 0.1, //
1907 },
1908 {
1909 0.7, 0.8, 0.1, 0.2, 0.3, //
1910 0.3, 0.2, 0.9, 0.8, 0.1, //
1911 },
1912 };
1913
1914 // Expected outputs.
1915 const std::vector<std::vector<int8_t>> expected_output = {
1916 {127, 127, -16, -21, 127, 127},
1917 {23, 127, 127, -128, 127, 127},
1918 {127, 127, 127, -128, 127, 127},
1919 };
1920
1921 // Invoke and verify the result.
1922 const int input_sequence_size = lstm_input.size();
1923 EXPECT_GT(input_sequence_size, 0);
1924 for (int i = 0; i < input_sequence_size; ++i) {
1925 lstm.SetInput(lstm_input[i]);
1926 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1927 EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
1928 }
1929 }
1930
1931 TEST(IntegerLstmOpTest, Cifg_NoPeephole_Projection_LayerNorm_8x8_8) {
1932 // Hyperparameters.
1933 const int n_batch = 2;
1934 const int n_input = 5;
1935 const int n_cell = 4;
1936 const int n_output = 3;
1937
1938 // Model-related weights.
1939 const std::vector<float> input_to_input_weights = {
1940 0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
1941 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1942
1943 const std::vector<float> input_to_forget_weights = {
1944 -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
1945 -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
1946
1947 const std::vector<float> input_to_cell_weights = {
1948 -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
1949 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
1950
1951 const std::vector<float> input_to_output_weights = {
1952 -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
1953 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
1954
1955 const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
1956
1957 const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
1958
1959 const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
1960
1961 const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
1962
1963 const std::vector<float> recurrent_to_input_weights = {
1964 -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1965
1966 const std::vector<float> recurrent_to_cell_weights = {
1967 -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1968
1969 const std::vector<float> recurrent_to_forget_weights = {
1970 -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1971
1972 const std::vector<float> recurrent_to_output_weights = {
1973 0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1974
1975 const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
1976 const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
1977 0.3};
1978 const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
1979 const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
1980 0.5};
1981
1982 const std::vector<float> projection_weights = {
1983 -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1984 const std::vector<float> projection_bias = {0.1, 0.3, 0.5};
1985
1986 // Input ranges.
1987 const std::vector<std::pair<float, float>> ranges = {
1988 {-1.0, 127.0 / 128}, // input tensor
1989 {-1.0, 1.0}, // input_to_input_weight tensor
1990 {-1.0, 1.0}, // input_to_forget_weight tensor
1991 {-1.0, 1.0}, // input_to_cell_weight tensor
1992 {-1.0, 1.0}, // input_to_output_weight tensor
1993
1994 {-1.0, 1.0}, // recurrent_to_input_weight tensor
1995 {-1.0, 1.0}, // recurrent_to_forget_weight tensor
1996 {-1.0, 1.0}, // recurrent_to_cell_weight tensor
1997 {-1.0, 1.0}, // recurrent_to_output_weight tensor
1998
1999 {-1, 1}, // cell_to_input_weight tensor
2000 {-1, 1}, // cell_to_forget_weight tensor
2001 {-1, 1}, // cell_to_output_weight tensor
2002
2003 {-100, 100}, // input_gate_bias tensor
2004 {-100, 100}, // forget_gate_bias tensor
2005 {-100, 100}, // cell_gate_bias tensor
2006 {-100, 100}, // output_gate_bias tensor
2007
2008 {-0.5, 0.5}, // projection_weight tensor
2009 {-1, 1}, // projection_bias tensor
2010
2011 {-1.0, 32767.0 / 32768}, // output_state tensor
2012 {-1.0, 32767.0 / 32768}, // cell_state tensor
2013
2014 {-1.00001, 1.0}, // input_layer_norm_coefficient tensor
2015 {-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
2016 {-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
2017 {-1.00001, 1.0}, // output_layer_norm_coefficient tensor
2018 // Output scale is the same as output_state scale and only output_state
2019 // scale is used in the op, so this is only provided for clarity.
2020 {-1.0, 32767.0 / 32768}, // output tensor.
2021 };
2022
2023 // The scale and zero point of intermediate tensors.
2024 std::vector<std::pair<float, int>> intermediates = {
2025 {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0},
2026 {0.007, 0}, {0.007059, 0}, {0.007, 0}, {0.007, 0},
2027 {0.007059, 0}, {0.007, 0}, {0.007, 0}, {0.3, 0}};
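  // The 8x8_8 variant consumes 12 (scale, zero_point) pairs instead of the 5
  // used by 8x8_16, as checked by the EXPECT_EQ in the model constructor; the
  // extra pairs presumably cover the additional stages that kernel quantizes.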
2028
2029 // Create model.
2030 LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
2031 /*use_cifg=*/true, /*use_peephole=*/false,
2032 /*use_projection_weights=*/true,
2033 /*use_projection_bias=*/true,
2034 /*use_layer_norm=*/true,
2035 /*use_8x8_8_implementation=*/true, ranges,
2036 intermediates);
2037
2038 // Allocate tensors and apply the delegate.
2039 lstm.PerformAllocateAndDelegate();
2040
2041 // Set weights. The input-gate weights, bias, and layer-norm coefficients
2041 // stay unset because this model uses CIFG.
2042 // lstm.SetInputToInputWeights(input_to_input_weights);
2043 lstm.SetInputToCellWeights(input_to_cell_weights);
2044 lstm.SetInputToForgetWeights(input_to_forget_weights);
2045 lstm.SetInputToOutputWeights(input_to_output_weights);
2046
2047 // lstm.SetInputGateBias(input_gate_bias);
2048 lstm.SetCellBias(cell_gate_bias);
2049 lstm.SetForgetGateBias(forget_gate_bias);
2050 lstm.SetOutputGateBias(output_gate_bias);
2051
2052 // lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
2053 lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
2054 lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
2055 lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
2056
2057 lstm.SetProjectionWeights(projection_weights);
2058 lstm.SetProjectionBias(projection_bias);
2059
2060 // lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
2061 lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
2062 lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
2063 lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
2064
2065 // Model inputs, laid out as sequence x batch x input.
2066 const std::vector<std::vector<float>> lstm_input = {
2067 {
2068 0.7, 0.8, 0.1, 0.2, 0.3, //
2069 0.8, 0.1, 0.2, 0.4, 0.5, //
2070 },
2071 {
2072 0.2, 0.7, 0.7, 0.1, 0.7, //
2073 0.3, 0.2, 0.9, 0.8, 0.1, //
2074 },
2075 {
2076 0.7, 0.8, 0.1, 0.2, 0.3, //
2077 0.3, 0.2, 0.9, 0.8, 0.1, //
2078 },
2079 };
2080
2081 // Expected outputs.
2082 const std::vector<std::vector<int8_t>> expected_output = {
2083 {127, 127, 127, 127, 127, 127},
2084 {127, 127, 127, 127, 127, 127},
2085 {127, 127, 127, 127, 127, 127},
2086 };
2087
2088 // Invoke and verify the result.
2089 const int input_sequence_size = lstm_input.size();
2090 EXPECT_GT(input_sequence_size, 0);
2091 for (int i = 0; i < input_sequence_size; ++i) {
2092 lstm.SetInput(lstm_input[i]);
2093 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2094 EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
2095 }
2096 }
2097
2098 #ifdef GTEST_HAS_DEATH_TEST
2099 TEST(LstmOpTest, InvalidTypes) {
2100 const int n_batch = 1;
2101 const int n_input = 2;
2102 const int n_cell = 4;
2103 const int n_output = 4;
2104
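  // Construction is expected to CHECK-fail for unsupported weight types: the
  // LSTM kernel only handles float32 weights and 8-bit quantized weights (the
  // uint8/int8 paths exercised by the tests above).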
2105 EXPECT_DEATH(LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
2106 /*use_cifg=*/false, /*use_peephole=*/false,
2107 /*use_projection_weights=*/false,
2108 /*use_projection_bias=*/false,
2109 /*weight_type=*/TensorType_INT32,
2110 /*model_has_legacy_20_inputs=*/true,
2111 /*is_layer_norm=*/false,
2112 /*asymmetric_quantize_inputs=*/false),
2113 "");
2114
2115 EXPECT_DEATH(LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
2116 /*use_cifg=*/false, /*use_peephole=*/false,
2117 /*use_projection_weights=*/false,
2118 /*use_projection_bias=*/false,
2119 /*weight_type=*/TensorType_COMPLEX64,
2120 /*model_has_legacy_20_inputs=*/true,
2121 /*is_layer_norm=*/false,
2122 /*asymmetric_quantize_inputs=*/false),
2123 "");
2124 }
2125 #endif
2126
2127 class HybridSparseLSTMOpModel : public ::tflite::SingleOpModel {
2128 public:
2129 HybridSparseLSTMOpModel(
2130 int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
2131 bool use_peephole, bool use_projection_weights, bool use_projection_bias,
2132 float cell_clip, float proj_clip,
2133 const std::vector<std::vector<int>>& input_shapes,
2134 const TensorData& input_weights_td,
2135 const std::vector<float>& input_to_input_weights,
2136 const std::vector<float>& input_to_forget_weights,
2137 const std::vector<float>& input_to_cell_weights,
2138 const std::vector<float>& input_to_output_weights,
2139 const TensorData& recurrent_weights_td,
2140 const std::vector<float>& recurrent_to_input_weights,
2141 const std::vector<float>& recurrent_to_forget_weights,
2142 const std::vector<float>& recurrent_to_cell_weights,
2143 const std::vector<float>& recurrent_to_output_weights,
2144 const ::tflite::TensorType& weight_type = ::tflite::TensorType_INT8)
2145 : n_batch_(n_batch),
2146 n_input_(n_input),
2147 n_cell_(n_cell),
2148 n_output_(n_output) {
2149 input_ = AddInput(::tflite::TensorType_FLOAT32);
2150
2151 if (use_cifg) {
2152 input_to_input_weights_ = AddNullInput();
2153 } else {
2154 input_to_input_weights_ =
2155 AddConstSparseInput(input_weights_td, input_to_input_weights, true);
2156 }
2157
2158 input_to_forget_weights_ =
2159 AddConstSparseInput(input_weights_td, input_to_forget_weights, true);
2160
2161 input_to_cell_weights_ =
2162 AddConstSparseInput(input_weights_td, input_to_cell_weights, true);
2163
2164 input_to_output_weights_ =
2165 AddConstSparseInput(input_weights_td, input_to_output_weights, true);
2166
2167 if (use_cifg) {
2168 recurrent_to_input_weights_ = AddNullInput();
2169 } else {
2170 recurrent_to_input_weights_ = AddConstSparseInput(
2171 recurrent_weights_td, recurrent_to_input_weights, true);
2172 }
2173
2174 recurrent_to_forget_weights_ = AddConstSparseInput(
2175 recurrent_weights_td, recurrent_to_forget_weights, true);
2176 recurrent_to_cell_weights_ = AddConstSparseInput(
2177 recurrent_weights_td, recurrent_to_cell_weights, true);
2178 recurrent_to_output_weights_ = AddConstSparseInput(
2179 recurrent_weights_td, recurrent_to_output_weights, true);
2180
2181 if (use_peephole) {
2182 if (use_cifg) {
2183 cell_to_input_weights_ = AddNullInput();
2184 } else {
2185 cell_to_input_weights_ = AddInput(weight_type);
2186 }
2187 cell_to_forget_weights_ = AddInput(weight_type);
2188 cell_to_output_weights_ = AddInput(weight_type);
2189 } else {
2190 cell_to_input_weights_ = AddNullInput();
2191 cell_to_forget_weights_ = AddNullInput();
2192 cell_to_output_weights_ = AddNullInput();
2193 }
2194
2195 if (use_cifg) {
2196 input_gate_bias_ = AddNullInput();
2197 } else {
2198 input_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
2199 }
2200 forget_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
2201 cell_bias_ = AddInput(::tflite::TensorType_FLOAT32);
2202 output_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
2203
2204 if (use_projection_weights) {
2205 projection_weights_ = AddInput(weight_type);
2206 if (use_projection_bias) {
2207 projection_bias_ = AddInput(::tflite::TensorType_FLOAT32);
2208 } else {
2209 projection_bias_ = AddNullInput();
2210 }
2211 } else {
2212 projection_weights_ = AddNullInput();
2213 projection_bias_ = AddNullInput();
2214 }
2215
2216 // Adding the 2 state tensors.
2217 output_state_ = AddVariableInput(::tflite::TensorData{
2218 ::tflite::TensorType_FLOAT32, {n_output_ * n_batch_}});
2219 cell_state_ = AddVariableInput(::tflite::TensorData{
2220 ::tflite::TensorType_FLOAT32, {n_cell_ * n_batch_}});
2221
2222 if (use_cifg) {
2223 input_layer_norm_weights_ = AddNullInput();
2224 } else {
2225 input_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
2226 }
2227 forget_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
2228 cell_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
2229 output_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
2230
2231 output_ = AddOutput(::tflite::TensorType_FLOAT32);
2232
2233 SetBuiltinOp(
2234 BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
2235 CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
2236 proj_clip, LSTMKernelType_FULL, false)
2237 .Union());
2238 BuildInterpreter(input_shapes);
2239 }
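// Hybrid sparse model: activations stay in float while the large weight
// matrices are stored as sparse constant tensors (the trailing `true` in
// AddConstSparseInput presumably requests symmetric int8 quantization of the
// dense data); the small per-cell vectors remain dense inputs.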
2240
2241 void SetCellToInputWeights(std::vector<float> f) {
2242 SignedSymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
2243 }
2244
2245 void SetCellToForgetWeights(std::vector<float> f) {
2246 SignedSymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
2247 }
2248
2249 void SetCellToOutputWeights(std::vector<float> f) {
2250 SignedSymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
2251 }
2252
2253 void SetInputLayerNormWeights(std::vector<float> f) {
2254 PopulateTensor(input_layer_norm_weights_, f);
2255 }
2256
2257 void SetForgetLayerNormWeights(std::vector<float> f) {
2258 PopulateTensor(forget_layer_norm_weights_, f);
2259 }
2260
2261 void SetCellLayerNormWeights(std::vector<float> f) {
2262 PopulateTensor(cell_layer_norm_weights_, f);
2263 }
2264
2265 void SetOutputLayerNormWeights(std::vector<float> f) {
2266 PopulateTensor(output_layer_norm_weights_, f);
2267 }
2268
2269 void SetInputGateBias(std::vector<float> f) {
2270 PopulateTensor(input_gate_bias_, f);
2271 }
2272
2273 void SetForgetGateBias(std::vector<float> f) {
2274 PopulateTensor(forget_gate_bias_, f);
2275 }
2276
2277 void SetCellBias(std::vector<float> f) { PopulateTensor(cell_bias_, f); }
2278
2279 void SetOutputGateBias(std::vector<float> f) {
2280 PopulateTensor(output_gate_bias_, f);
2281 }
2282
2283 void SetProjectionWeights(std::vector<float> f) {
2284 SignedSymmetricQuantizeAndPopulate(projection_weights_, f);
2285 }
2286
2287 void SetProjectionBias(std::vector<float> f) {
2288 PopulateTensor(projection_bias_, f);
2289 }
2290
2291 void SetInput(int offset, const float* begin, const float* end) {
2292 PopulateTensor(input_, offset, const_cast<float*>(begin),
2293 const_cast<float*>(end));
2294 }
2295
2296 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2297
2298 int num_inputs() { return n_input_; }
2299 int num_outputs() { return n_output_; }
2300 int num_cells() { return n_cell_; }
2301 int num_batches() { return n_batch_; }
2302
2303 protected:
2304 int input_;
2305 int input_to_input_weights_;
2306 int input_to_forget_weights_;
2307 int input_to_cell_weights_;
2308 int input_to_output_weights_;
2309
2310 int recurrent_to_input_weights_;
2311 int recurrent_to_forget_weights_;
2312 int recurrent_to_cell_weights_;
2313 int recurrent_to_output_weights_;
2314
2315 int cell_to_input_weights_;
2316 int cell_to_forget_weights_;
2317 int cell_to_output_weights_;
2318
2319 int input_layer_norm_weights_;
2320 int forget_layer_norm_weights_;
2321 int cell_layer_norm_weights_;
2322 int output_layer_norm_weights_;
2323
2324 int input_gate_bias_;
2325 int forget_gate_bias_;
2326 int cell_bias_;
2327 int output_gate_bias_;
2328
2329 int projection_weights_;
2330 int projection_bias_;
2331
2332 int output_state_;
2333 int cell_state_;
2334
2335 int output_;
2336
2337 int n_batch_;
2338 int n_input_;
2339 int n_cell_;
2340 int n_output_;
2341 };
2342
2343 class BaseSparseLstmTest : public ::testing::Test {
2344 protected:
2345 // Weights of the Sparse Layer Norm LSTM model. Some are optional.
2346 std::vector<float> input_to_input_weights_;
2347 std::vector<float> input_to_cell_weights_;
2348 std::vector<float> input_to_forget_weights_;
2349 std::vector<float> input_to_output_weights_;
2350 std::vector<float> input_gate_bias_;
2351 std::vector<float> cell_gate_bias_;
2352 std::vector<float> forget_gate_bias_;
2353 std::vector<float> output_gate_bias_;
2354 std::vector<float> recurrent_to_input_weights_;
2355 std::vector<float> recurrent_to_cell_weights_;
2356 std::vector<float> recurrent_to_forget_weights_;
2357 std::vector<float> recurrent_to_output_weights_;
2358 std::vector<float> cell_to_input_weights_;
2359 std::vector<float> cell_to_forget_weights_;
2360 std::vector<float> cell_to_output_weights_;
2361 std::vector<float> input_layer_norm_weights_;
2362 std::vector<float> forget_layer_norm_weights_;
2363 std::vector<float> cell_layer_norm_weights_;
2364 std::vector<float> output_layer_norm_weights_;
2365 std::vector<float> projection_weights_;
2366
2367 std::vector<int> input_to_input_weights_size_;
2368 std::vector<int> input_to_cell_weights_size_;
2369 std::vector<int> input_to_forget_weights_size_;
2370 std::vector<int> input_to_output_weights_size_;
2371 std::vector<int> recurrent_to_input_weights_size_;
2372 std::vector<int> recurrent_to_cell_weights_size_;
2373 std::vector<int> recurrent_to_forget_weights_size_;
2374 std::vector<int> recurrent_to_output_weights_size_;
2375
2376 int n_batch_;
2377 int n_input_;
2378 int n_cell_;
2379 int n_output_;
2380 float cell_clip_;
2381 float proj_clip_;
2382
2383 // Layer Norm LSTM input, stored per batch as a flattened
2383 // (input_sequence_size * n_input_) vector.
2384 std::vector<std::vector<float>> sparse_layer_norm_lstm_input_;
2385
2386 // Runs the given input through the layer-norm LSTM and compares its output
2387 // to the golden output, up to the given tolerance.
2388 void VerifyGoldens(const std::vector<std::vector<float>>& input,
2389 const std::vector<std::vector<float>>& output,
2390 HybridSparseLSTMOpModel* sparse_layer_norm_lstm,
2391 float tolerance = 1e-5) {
2392 const int num_batches = input.size();
2393 EXPECT_GT(num_batches, 0);
2394 const int num_inputs = sparse_layer_norm_lstm->num_inputs();
2395 EXPECT_GT(num_inputs, 0);
2396 const int input_sequence_size = input[0].size() / num_inputs;
2397 EXPECT_GT(input_sequence_size, 0);
2398 for (int i = 0; i < input_sequence_size; ++i) {
2399 for (int b = 0; b < num_batches; ++b) {
2400 const float* batch_start = input[b].data() + i * num_inputs;
2401 const float* batch_end = batch_start + num_inputs;
2402
2403 sparse_layer_norm_lstm->SetInput(
2404 b * sparse_layer_norm_lstm->num_inputs(), batch_start, batch_end);
2405 }
2406
2407 ASSERT_EQ(sparse_layer_norm_lstm->Invoke(), kTfLiteOk);
2408
2409 const int num_outputs = sparse_layer_norm_lstm->num_outputs();
2410 std::vector<float> expected;
2411 for (int b = 0; b < num_batches; ++b) {
2412 const float* golden_start_batch = output[b].data() + i * num_outputs;
2413 const float* golden_end_batch = golden_start_batch + num_outputs;
2414 expected.insert(expected.end(), golden_start_batch, golden_end_batch);
2415 }
2416 EXPECT_THAT(
2417 sparse_layer_norm_lstm->GetOutput(),
2418 ElementsAreArray(::tflite::ArrayFloatNear(expected, tolerance)));
2419 }
2420 }
2421 };
2422
2423 class NoCifgPeepholeProjectionNoClippingSparseLstmTest
2424 : public BaseSparseLstmTest {
2425 void SetUp() override {
2426 n_batch_ = 2;
2427 n_input_ = 48;
2428 n_cell_ = 4;
2429 n_output_ = 16;
2430 cell_clip_ = 0.0;
2431 proj_clip_ = 0.0;
2432
2433 /* clang-format off */
2434 input_to_input_weights_ = {
2435 /* 1st row */
2436 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
2437 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2438 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
2439 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
2440 /* 2nd row */
2441 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2442 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
2443 -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2444 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2445 /* 3rd row */
2446 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2447 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
2448 -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2449 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2450 /* 4th row */
2451 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
2452 -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2453 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
2454 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
2455 input_to_input_weights_size_ = {4, 48};
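    // The 4 x 48 matrices here are written with long zero runs so that, once
    // converted to the sparse (block-compressed) format, whole blocks can be
    // skipped; only the nonzero blocks should survive compression.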
2456
2457 input_to_forget_weights_ = {
2458 /* 1st row */
2459 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
2460 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2461 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
2462 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
2463 /* 2nd row */
2464 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2465 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
2466 -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2467 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2468 /* 3rd row */
2469 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2470 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
2471 -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2472 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2473 /* 4th row */
2474 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
2475 -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2476 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
2477 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
2478 input_to_forget_weights_size_ = {4, 48};
2479
2480 input_to_cell_weights_ = {
2481 /* 1st row */
2482 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
2483 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2484 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
2485 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
2486 /* 2nd row */
2487 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2488 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
2489 -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2490 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2491 /* 3rd row */
2492 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2493 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
2494 -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2495 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2496 /* 4th row */
2497 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
2498 -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2499 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
2500 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
2501 input_to_cell_weights_size_ = {4, 48};
2502
2503 input_to_output_weights_ = {
2504 /* 1st row */
2505 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
2506 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2507 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
2508 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
2509 /* 2nd row */
2510 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2511 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
2512 -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2513 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2514 /* 3rd row */
2515 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2516 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
2517 -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2518 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2519 /* 4th row */
2520 -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
2521 -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2522 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
2523 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
2524 input_to_output_weights_size_ = {4, 48};
2525
2526 input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
2527
2528 forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
2529
2530 cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
2531
2532 output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
2533
2534 recurrent_to_input_weights_ = {
2535 -0.2, -0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2536 0.0, 0.0, // 1st row
2537 0.1, -0.5, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2538 0.0, 0.0, // 2nd row
2539 -0.2, -0.3, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2540 0.0, 0.0, // 3rd row
2541 0.05, -0.2, -0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2542 0.0, 0.0, // 4th row
2543 };
2544 recurrent_to_input_weights_size_ = {4, 16};

    recurrent_to_cell_weights_ = {
        -0.3, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        -0.3, 0.8, -0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        -0.2, 0.3, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        -0.6, -0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_cell_weights_size_ = {4, 16};

    recurrent_to_forget_weights_ = {
        -0.5, -0.3, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        -0.2, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        0.9, 0.3, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        0.2, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_forget_weights_size_ = {4, 16};

    recurrent_to_output_weights_ = {
        0.3, -0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        -0.2, -0.5, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        -0.2, -0.6, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        -0.4, -0.7, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_output_weights_size_ = {4, 16};
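    // Only the first three columns of each recurrent matrix are nonzero: the
    // projection below zeroes out all but the first three output units, so the
    // remaining 13 recurrent inputs never carry a signal.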

    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};

    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};

    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
    forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};

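    // The projection maps the 4 cell units to 16 outputs; only its first three
    // rows are nonzero, so only the first three output units can be nonzero
    // (see the golden outputs in the test below).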
    projection_weights_ = {
        -0.1, 0.2, 0.01, -0.2,  // 1st row
        0.1, 0.5, 0.3, 0.08,    // 2nd row
        0.07, 0.2, -0.4, 0.2,   // 3rd row
        0.0, 0.0, 0.0, 0.0,     // 4th row
        0.0, 0.0, 0.0, 0.0,     // 5th row
        0.0, 0.0, 0.0, 0.0,     // 6th row
        0.0, 0.0, 0.0, 0.0,     // 7th row
        0.0, 0.0, 0.0, 0.0,     // 8th row
        0.0, 0.0, 0.0, 0.0,     // 9th row
        0.0, 0.0, 0.0, 0.0,     // 10th row
        0.0, 0.0, 0.0, 0.0,     // 11th row
        0.0, 0.0, 0.0, 0.0,     // 12th row
        0.0, 0.0, 0.0, 0.0,     // 13th row
        0.0, 0.0, 0.0, 0.0,     // 14th row
        0.0, 0.0, 0.0, 0.0,     // 15th row
        0.0, 0.0, 0.0, 0.0,     // 16th row
    };

    sparse_layer_norm_lstm_input_ = {
        // Batch0: 2 (input_sequence_size) * 48 (n_input_)
        {
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // seq 0
            2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
            -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
            0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
            1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // seq 1
        },
        // Batch1: 2 (input_sequence_size) * 48 (n_input_)
        {
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // seq 0
            2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
            -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
            0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
            1.0, -2.5, 0.7, -1.9, 0.2, -1.0, 1.0, -1.0,  // seq 1
        },
    };
    /* clang-format on */
  }
};

TEST_F(NoCifgPeepholeProjectionNoClippingSparseLstmTest,
       HybridSparseLstmBlackBoxTest) {
  TensorData input_weight = {};
  input_weight.type = TensorType_FLOAT32;
  input_weight.shape = {4, 48};
  input_weight.traversal_order = {0, 1, 2};
  input_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  input_weight.block_map = {1};
  input_weight.block_size = {16};
  TensorData recurrent_weight = {};
  recurrent_weight.type = TensorType_FLOAT32;
  recurrent_weight.shape = {4, 16};
  recurrent_weight.traversal_order = {0, 1, 2};
  recurrent_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  recurrent_weight.block_map = {1};
  recurrent_weight.block_size = {16};
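  // Both weight tensors above use TfLite's block-sparse encoding: dimension 0
  // stays dense and dimension 1 is stored as CSR over blocks of 16 columns
  // (block_map = {1}, block_size = {16}); the extra entry in traversal_order
  // indexes the elements inside a block.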
  HybridSparseLSTMOpModel sparse_layer_norm_lstm(
      n_batch_, n_input_, n_cell_, n_output_,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip_, proj_clip_,
      {
          {n_batch_, n_input_},  // input tensor

          {input_to_input_weights_size_},
          {input_to_forget_weights_size_},
          {input_to_cell_weights_size_},
          {input_to_output_weights_size_},

          {recurrent_to_input_weights_size_},
          {recurrent_to_forget_weights_size_},
          {recurrent_to_cell_weights_size_},
          {recurrent_to_output_weights_size_},

          {n_cell_},  // cell_to_input_weight tensor
          {n_cell_},  // cell_to_forget_weight tensor
          {n_cell_},  // cell_to_output_weight tensor

          {n_cell_},  // input_gate_bias tensor
          {n_cell_},  // forget_gate_bias tensor
          {n_cell_},  // cell_bias tensor
          {n_cell_},  // output_gate_bias tensor

          {n_output_, n_cell_},  // projection_weight tensor
          {0},                   // projection_bias tensor

          {n_output_ * n_batch_},  // output_state tensor
          {n_cell_ * n_batch_},    // cell_state tensor

          {n_cell_},  // input_layer_norm_weight tensor
          {n_cell_},  // forget_layer_norm_weight tensor
          {n_cell_},  // cell_layer_norm_weight tensor
          {n_cell_},  // output_layer_norm_weight tensor
      },
      input_weight, input_to_input_weights_, input_to_forget_weights_,
      input_to_cell_weights_, input_to_output_weights_, recurrent_weight,
      recurrent_to_input_weights_, recurrent_to_forget_weights_,
      recurrent_to_cell_weights_, recurrent_to_output_weights_);
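  // The initializer list above only provides shapes; the two sparse TensorData
  // structs and the weight values are passed alongside them, presumably so the
  // model can register the weights as constant tensors in compressed form.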

  sparse_layer_norm_lstm.SetInputGateBias(input_gate_bias_);
  sparse_layer_norm_lstm.SetCellBias(cell_gate_bias_);
  sparse_layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  sparse_layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  sparse_layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
  sparse_layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  sparse_layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  sparse_layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
  sparse_layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
  sparse_layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
  sparse_layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);

  sparse_layer_norm_lstm.SetProjectionWeights(projection_weights_);

  /* clang-format off */
  const std::vector<std::vector<float>> sparse_layer_norm_lstm_golden_output = {
      {
          // Batch0: 2 (input_sequence_size) * 16 (n_output_)
          0.0559981, 0.140761, -0.0618812, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
          0.070831, 0.200455, -0.0581763, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
      },
      {
          // Batch1: 2 (input_sequence_size) * 16 (n_output_)
          0.0559981, 0.140761, -0.0618812, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
          0.070831, 0.200455, -0.0581763, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
      }};
  /* clang-format on */

  VerifyGoldens(sparse_layer_norm_lstm_input_,
                sparse_layer_norm_lstm_golden_output, &sparse_layer_norm_lstm);
}

// Test parameters control the weight type and two boolean options of
// LSTMOpModel, one of which is asymmetric_quantize_inputs.
INSTANTIATE_TEST_SUITE_P(
    Parameterized, LstmOpTest,
    ::testing::Combine(::testing::Values(TensorType_FLOAT32, TensorType_UINT8,
                                         TensorType_INT8),
                       ::testing::Bool(), ::testing::Bool()));
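// A minimal sketch of how a parameterized case consumes this tuple (the real
// LstmOpTest fixture and its TEST_P bodies are defined earlier in this file;
// the case name here is illustrative only):
//
//   TEST_P(LstmOpTest, IllustrativeCase) {
//     TensorType weight_type;
//     bool model_has_legacy_20_inputs, asymmetric_quantize_inputs;
//     std::tie(weight_type, model_has_legacy_20_inputs,
//              asymmetric_quantize_inputs) = GetParam();
//     // ... construct an LSTMOpModel with these options and verify goldens.
//   }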

}  // namespace
}  // namespace tflite