/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite LSTM op.
//
// TODO(alanchiao): add unit test with invalid input dimensions for this and its
// variants.

#include <stdint.h>

#include <map>
#include <tuple>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

class LSTMOpModel : public SingleOpModel {
 public:
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, const TensorType weight_type,
              bool model_has_legacy_20_inputs, bool is_layer_norm,
              bool asymmetric_quantize_inputs)
      : n_input_(n_input),
        n_output_(n_output),
        n_batch_(n_batch),
        weight_type_(weight_type) {
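    // Each Add*Input()/AddNullInput() call below returns the new tensor's
    // index in the model; the index is stored so the tensor can be populated
    // later. The order of the calls must match the input order the LSTM
    // kernel expects.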
    input_ = AddInput({TensorType_FLOAT32, {n_batch, n_input}});

    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput({weight_type, {n_cell, n_input}});
    }
    input_to_forget_weights_ = AddInput({weight_type, {n_cell, n_input}});
    input_to_cell_weights_ = AddInput({weight_type, {n_cell, n_input}});
    input_to_output_weights_ = AddInput({weight_type, {n_cell, n_input}});

    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput({weight_type, {n_cell, n_output}});
    }
    recurrent_to_forget_weights_ = AddInput({weight_type, {n_cell, n_output}});
    recurrent_to_cell_weights_ = AddInput({weight_type, {n_cell, n_output}});
    recurrent_to_output_weights_ = AddInput({weight_type, {n_cell, n_output}});

    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput({weight_type, {n_cell}});
      }
      cell_to_forget_weights_ = AddInput({weight_type, {n_cell}});
      cell_to_output_weights_ = AddInput({weight_type, {n_cell}});
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});
    }
    forget_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});
    cell_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});
    output_gate_bias_ = AddInput({TensorType_FLOAT32, {n_cell}});

    if (use_projection_weights) {
      projection_weights_ = AddInput({weight_type, {n_output, n_cell}});
    } else {
      projection_weights_ = AddNullInput();
    }
    if (use_projection_bias) {
      CHECK(use_projection_weights);
      projection_bias_ = AddInput({TensorType_FLOAT32, {n_output}});
    } else {
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 state tensors.
    AddVariableInput({TensorType_FLOAT32, {n_batch, n_output}});
    AddVariableInput({TensorType_FLOAT32, {n_batch, n_cell}});
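    // The first is the output (activation) state, the second the cell state;
    // as variable tensors they persist between invocations, which is what
    // lets the multi-step golden tests below carry state across steps.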

    // Layer norm weights.
    if (!model_has_legacy_20_inputs) {
      if (is_layer_norm) {
        if (use_cifg) {
          input_layer_norm_coefficients_ = AddNullInput();
        } else {
          input_layer_norm_coefficients_ =
              AddInput({TensorType_FLOAT32, {n_cell}});
        }
        forget_layer_norm_coefficients_ =
            AddInput({TensorType_FLOAT32, {n_cell}});
        cell_layer_norm_coefficients_ =
            AddInput({TensorType_FLOAT32, {n_cell}});
        output_layer_norm_coefficients_ =
            AddInput({TensorType_FLOAT32, {n_cell}});
      } else {
        input_layer_norm_coefficients_ = AddNullInput();
        forget_layer_norm_coefficients_ = AddNullInput();
        cell_layer_norm_coefficients_ = AddNullInput();
        output_layer_norm_coefficients_ = AddNullInput();
      }
    }
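    // When model_has_legacy_20_inputs is true the model keeps the original
    // 20-input LSTM signature; otherwise the four (possibly null) layer-norm
    // tensors above extend it to the newer 24-input signature.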

    output_ = AddOutput({TensorType_FLOAT32, {n_batch, n_output}});

    // TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the
    // default 0.
    SetBuiltinOp(
        BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
        CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                          /*cell_clip=*/0.0f, /*proj_clip=*/0.0f,
                          LSTMKernelType_FULL, asymmetric_quantize_inputs)
            .Union());

    // Input shapes are already set up, no need to pass them again.
    BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1,
                     /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/false);
  }

  void SetInputToInputWeights(const std::vector<float>& f) {
    SetWeights(input_to_input_weights_, f);
  }

  void SetInputToForgetWeights(const std::vector<float>& f) {
    SetWeights(input_to_forget_weights_, f);
  }

  void SetInputToCellWeights(const std::vector<float>& f) {
    SetWeights(input_to_cell_weights_, f);
  }

  void SetInputToOutputWeights(const std::vector<float>& f) {
    SetWeights(input_to_output_weights_, f);
  }

  void SetRecurrentToInputWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_input_weights_, f);
  }

  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_forget_weights_, f);
  }

  void SetRecurrentToCellWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_cell_weights_, f);
  }

  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
    SetWeights(recurrent_to_output_weights_, f);
  }

  void SetCellToInputWeights(const std::vector<float>& f) {
    SetWeights(cell_to_input_weights_, f);
  }

  void SetCellToForgetWeights(const std::vector<float>& f) {
    SetWeights(cell_to_forget_weights_, f);
  }

  void SetCellToOutputWeights(const std::vector<float>& f) {
    SetWeights(cell_to_output_weights_, f);
  }

  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(input_layer_norm_coefficients_, f);
  }

  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(forget_layer_norm_coefficients_, f);
  }

  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(cell_layer_norm_coefficients_, f);
  }

  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(output_layer_norm_coefficients_, f);
  }

  void SetInputGateBias(const std::vector<float>& f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(const std::vector<float>& f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(const std::vector<float>& f) {
    PopulateTensor(cell_gate_bias_, f);
  }

  void SetOutputGateBias(const std::vector<float>& f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(const std::vector<float>& f) {
    SetWeights(projection_weights_, f);
  }

  void SetProjectionBias(const std::vector<float>& f) {
    PopulateTensor(projection_bias_, f);
  }

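  // Populates a slice of the float input tensor: `offset` is measured in
  // elements from the start of the tensor, and [begin, end) holds one
  // batch's worth of values.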
  void SetInput(int offset, const float* begin, const float* end) {
    SingleOpModel::PopulateTensor(input_, offset, const_cast<float*>(begin),
                                  const_cast<float*>(end));
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_batches() { return n_batch_; }

 protected:
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_layer_norm_coefficients_ = kTfLiteOptionalTensor;
  int forget_layer_norm_coefficients_ = kTfLiteOptionalTensor;
  int cell_layer_norm_coefficients_ = kTfLiteOptionalTensor;
  int output_layer_norm_coefficients_ = kTfLiteOptionalTensor;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_gate_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;

  int output_;

  int n_input_;
  int n_output_;
  int n_batch_;

 private:
  void PopulateTensor(int index, const std::vector<float>& data) {
    // Nothing to do if tensor is an optional input or if data vector is empty.
    if ((index == kTfLiteOptionalTensor) || data.empty()) return;
    SingleOpModel::PopulateTensor(index, data);
  }

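  // Sets weight values, quantizing them first when weight_type_ is a
  // quantized type: symmetric quantization for UINT8, the signed symmetric
  // scheme for INT8.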
  void SetWeights(int index, const std::vector<float>& data) {
    if (data.empty()) return;
    if (index == kTfLiteOptionalTensor) return;
    switch (weight_type_) {
      case TensorType_FLOAT32:
        PopulateTensor(index, data);
        break;
      case TensorType_UINT8:
        SymmetricQuantizeAndPopulate(index, data);
        break;
      case TensorType_INT8:
        SignedSymmetricQuantizeAndPopulate(index, data);
        break;
      default:
        GTEST_FAIL() << "Type not supported: " << weight_type_;
        break;
    }
  }

  const TensorType weight_type_;
};

// Parameters:
// std::get<0>(GetParam()) => weight_type
// std::get<1>(GetParam()) => model_has_legacy_20_inputs
// std::get<2>(GetParam()) => asymmetric_quantize_inputs
class LstmOpTest
    : public ::testing::TestWithParam<std::tuple<TensorType, bool, bool>> {
 protected:
  // Weights of the LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> projection_weights_;
  std::vector<float> input_layer_norm_coefficients_;
  std::vector<float> forget_layer_norm_coefficients_;
  std::vector<float> cell_layer_norm_coefficients_;
  std::vector<float> output_layer_norm_coefficients_;

  // LSTM input is stored as a num_steps * num_batch * num_inputs vector.
  std::vector<std::vector<std::vector<float>>> lstm_input_;
  // LSTM output is stored as a num_steps * num_batch * num_outputs vector.
  std::vector<std::vector<std::vector<float>>> lstm_golden_output_;

  // Runs the LSTM on lstm_input_ and compares the result against
  // lstm_golden_output_, up to the given tolerance.
  void VerifyGoldens(LSTMOpModel* lstm, float tolerance) {
    // The delegate, if used, needs to know the scales and zero-points of
    // quantized tensors, which are computed dynamically when weights are set,
    // so weights have to be set before applying the delegate.
    SetAllWeightsAndBiases(lstm);
    lstm->ApplyDelegate();

    const int num_inputs = lstm->num_inputs();
    const int num_outputs = lstm->num_outputs();
    const int num_batches = lstm->num_batches();

    ASSERT_EQ(lstm_input_.size(), lstm_golden_output_.size());
    const int num_steps = lstm_input_.size();

    for (int i = 0; i < num_steps; ++i) {
      ASSERT_EQ(num_batches, lstm_input_[i].size());
      for (int b = 0; b < num_batches; ++b) {
        ASSERT_EQ(num_inputs, lstm_input_[i][b].size());
        const float* batch_start = lstm_input_[i][b].data();
        const float* batch_end = batch_start + num_inputs;
        lstm->SetInput(b * num_inputs, batch_start, batch_end);
      }

      ASSERT_EQ(lstm->Invoke(), kTfLiteOk);

      std::vector<float> expected;
      ASSERT_EQ(num_batches, lstm_golden_output_[i].size());
      for (int b = 0; b < num_batches; ++b) {
        ASSERT_EQ(num_outputs, lstm_golden_output_[i][b].size());
        const float* batch_start = lstm_golden_output_[i][b].data();
        const float* batch_end = batch_start + num_outputs;
        expected.insert(expected.end(), batch_start, batch_end);
      }

      EXPECT_THAT(lstm->GetOutput(),
                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
    }
  }

  // Sets all weights and biases that have been defined by the test. The test
  // can define only a subset of all those vectors, and only the ones that
  // have been defined will be set.
  void SetAllWeightsAndBiases(LSTMOpModel* lstm) {
    lstm->SetInputToInputWeights(input_to_input_weights_);
    lstm->SetInputToCellWeights(input_to_cell_weights_);
    lstm->SetInputToForgetWeights(input_to_forget_weights_);
    lstm->SetInputToOutputWeights(input_to_output_weights_);

    lstm->SetInputGateBias(input_gate_bias_);
    lstm->SetCellBias(cell_gate_bias_);
    lstm->SetForgetGateBias(forget_gate_bias_);
    lstm->SetOutputGateBias(output_gate_bias_);

    lstm->SetRecurrentToInputWeights(recurrent_to_input_weights_);
    lstm->SetRecurrentToCellWeights(recurrent_to_cell_weights_);
    lstm->SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
    lstm->SetRecurrentToOutputWeights(recurrent_to_output_weights_);

    lstm->SetCellToInputWeights(cell_to_input_weights_);
    lstm->SetCellToForgetWeights(cell_to_forget_weights_);
    lstm->SetCellToOutputWeights(cell_to_output_weights_);

    lstm->SetProjectionWeights(projection_weights_);

    lstm->SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
    lstm->SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
    lstm->SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
    lstm->SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
  }
};
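
// Note: the (weight_type, model_has_legacy_20_inputs,
// asymmetric_quantize_inputs) tuples are presumably expanded by an
// INSTANTIATE_TEST_SUITE_P later in the file (not shown in this excerpt).
// A minimal sketch, assuming every combination is exercised, would be:
//
//   INSTANTIATE_TEST_SUITE_P(
//       Parameterized, LstmOpTest,
//       ::testing::Combine(::testing::Values(TensorType_FLOAT32,
//                                            TensorType_UINT8,
//                                            TensorType_INT8),
//                          ::testing::Bool(), ::testing::Bool()));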

TEST_P(LstmOpTest, NoCifg_NoPeephole_NoProjection_NoLayerNorm) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  TensorType weight_type;
  bool model_has_legacy_20_inputs;
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, model_has_legacy_20_inputs,
           asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,  -0.34550029,
                             0.04266912,  -0.15680569, -0.34856534, 0.43890524};
  input_to_cell_weights_ = {-0.50013041, 0.1370284,  0.11810488, 0.2013163,
                            -0.20583314, 0.44344562, 0.22077113, -0.29909778};
  input_to_forget_weights_ = {0.09701663,  0.20334584,  -0.50592935,
                              -0.31343272, -0.40032279, 0.44781327,
                              0.01387155,  -0.35593212};
  input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, 0.40525138,
                              0.44272184,  0.03897077,  -0.1556896, 0.19487578};
  input_gate_bias_ = {0., 0., 0., 0.};
  cell_gate_bias_ = {0., 0., 0., 0.};
  forget_gate_bias_ = {1., 1., 1., 1.};
  output_gate_bias_ = {0., 0., 0., 0.};

  recurrent_to_input_weights_ = {
      -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
      0.28902304,  0.08183324,  -0.16555229, 0.02286911,
      -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
      0.24077177,  -0.51332325, -0.33502164, 0.10629296};

  recurrent_to_cell_weights_ = {
      -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
      0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
      -0.06418842, -0.13502428, -0.501764,   0.22830659,
      -0.46367589, 0.26016325,  -0.03894562, -0.16368064};

  recurrent_to_forget_weights_ = {
      -0.48684245, -0.06655136, 0.42224967,  0.2112639,
      0.27654213,  0.20864892,  -0.07646349, 0.45877004,
      0.00141793,  -0.14609534, 0.36447752,  0.09196436,
      0.28053468,  0.01560611,  -0.20127171, -0.01140004};

  recurrent_to_output_weights_ = {
      0.43385774,  -0.17194885, 0.2718237,  0.09215671,
      0.24107647,  -0.39835793, 0.18212086, 0.01301402,
      0.48572797,  -0.50656658, 0.20047462, -0.20607421,
      -0.51818722, -0.15390486, 0.0468148,  0.39922136};

  // num_steps * num_batch * num_inputs
  lstm_input_ = {{{2., 3.}}, {{3., 4.}}, {{1., 1.}}};
  // num_steps * num_batch * num_outputs
  lstm_golden_output_ = {{{-0.02973187, 0.1229473, 0.20885126, -0.15358765}},
                         {{-0.03716109, 0.12507336, 0.41193449, -0.20860538}},
                         {{-0.15053082, 0.09120187, 0.24278517, -0.12222792}}};

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false, weight_type,
                   model_has_legacy_20_inputs,
                   /*is_layer_norm=*/false, asymmetric_quantize_inputs);

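  // Quantizing the weights costs accuracy, hence the looser tolerances for
  // the quantized weight types.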
  static const auto* tolerance_per_type =
      new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
                                      {TensorType_UINT8, 0.0157651f},
                                      {TensorType_INT8, 0.0157651f}};
  VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
}

TEST_P(LstmOpTest, Cifg_Peephole_NoProjection_NoLayerNorm) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  TensorType weight_type;
  bool model_has_legacy_20_inputs;
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, model_has_legacy_20_inputs,
           asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, 0.05100781,
                            0.04717243,  0.48944736,  -0.38535351, -0.17212132};

  input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, -0.3633365,
                              -0.22755712, 0.28253698,  0.24407166, 0.33826375};

  input_to_output_weights_ = {0.10725588,  -0.02335852, -0.55932593,
                              -0.09426838, -0.44257352, 0.54939759,
                              0.01533556,  0.42751634};
  cell_gate_bias_ = {0., 0., 0., 0.};
  forget_gate_bias_ = {1., 1., 1., 1.};
  output_gate_bias_ = {0., 0., 0., 0.};

  recurrent_to_cell_weights_ = {
      0.54066205,  -0.32668582, -0.43562764, -0.56094903,
      0.42957711,  0.01841056,  -0.32764608, -0.33027974,
      -0.10826075, 0.20675004,  0.19069612,  -0.03026325,
      -0.54532051, 0.33003211,  0.44901288,  0.21193194};

  recurrent_to_forget_weights_ = {
      -0.13832897, -0.0515101,  -0.2359007, -0.16661474,
      -0.14340827, 0.36986142,  0.23414481, 0.55899,
      0.10798943,  -0.41174671, 0.17751795, -0.34484994,
      -0.35874045, -0.11352962, 0.27268326, 0.54058349};

  recurrent_to_output_weights_ = {
      0.41613156, 0.42610586,  -0.16495961, -0.5663873,
      0.30579174, -0.05115908, -0.33941799, 0.23364776,
      0.11178309, 0.09481031,  -0.26424935, 0.46261835,
      0.50248802, 0.26114327,  -0.43736315, 0.33149987};

  cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, 0.31544167};
  cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, -0.77109635};

  lstm_input_ = {{{2., 3.}}, {{3., 4.}}, {{1., 1.}}};
  lstm_golden_output_ = {{{-0.36444446, -0.00352185, 0.12886585, -0.05163646}},
                         {{-0.42312205, -0.01218222, 0.24201041, -0.08124574}},
                         {{-0.358325, -0.04621704, 0.21641694, -0.06471302}}};

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/true, /*use_peephole=*/true,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false, weight_type,
                   model_has_legacy_20_inputs, /*is_layer_norm=*/false,
                   asymmetric_quantize_inputs);

  static const auto* tolerance_per_type =
      new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
                                      {TensorType_UINT8, 0.03573f},
                                      {TensorType_INT8, 0.03573f}};
  VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
}

TEST_P(LstmOpTest, NoCifg_Peephole_Projection_NoLayerNorm) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 20;
  const int n_output = 16;

  TensorType weight_type;
  bool model_has_legacy_20_inputs;
  bool asymmetric_quantize_inputs;
  std::tie(weight_type, model_has_legacy_20_inputs,
           asymmetric_quantize_inputs) = GetParam();

  // TODO(b/158205028): Fix this test if using NN-API.
  if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
    return;
  }

  input_to_input_weights_ = {
      0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
      0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
      -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
      -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
      -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
      -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
      -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
      0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
      0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
      0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
      -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
      0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
      -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
      -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
      -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
      0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
      -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
      -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
      -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
      -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677};

  input_to_forget_weights_ = {
      -0.0018401089, -0.004852237,  0.03698424,   0.014181704,   0.028273236,
      -0.016726194,  -0.05249759,   -0.10204261,  0.00861066,    -0.040979505,
      -0.009899187,  0.01923892,    -0.028177269, -0.08535103,   -0.14585495,
      0.10662567,    -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
      0.0030784295,  0.076784775,   0.07463696,   0.094531395,   0.0814421,
      -0.12257899,   -0.033945758,  -0.031303465, 0.045630626,   0.06843887,
      -0.13492945,   -0.012480007,  -0.0811829,   -0.07224499,   -0.09628791,
      0.045100946,   0.0012300825,  0.013964662,  0.099372394,   0.02543059,
      0.06958324,    0.034257296,   0.0482646,    0.06267997,    0.052625068,
      0.12784666,    0.07077897,    0.025725935,  0.04165009,    0.07241905,
      0.018668644,   -0.037377294,  -0.06277783,  -0.08833636,   -0.040120605,
      -0.011405586,  -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
      0.05483423,    0.11449111,    0.11289652,   0.10939839,    0.13396506,
      -0.08402166,   -0.01901462,   -0.044678304, -0.07720565,   0.014350063,
      -0.11757958,   -0.0652038,    -0.08185733,  -0.076754324,  -0.092614375,
      0.10405491,    0.052960336,   0.035755895,  0.035839386,   -0.012540553,
      0.036881298,   0.02913376,    0.03420159,   0.05448447,    -0.054523353,
      0.02582715,    0.02327355,    -0.011857179, -0.0011980024, -0.034641717,
      -0.026125094,  -0.17582615,   -0.15923657,  -0.27486774,   -0.0006143371,
      0.0001771948,  -8.470171e-05, 0.02651807,   0.045790765,   0.06956496};

  input_to_cell_weights_ = {
      -0.04580283,  -0.09549462,   -0.032418985,  -0.06454633,   -0.043528453,
      0.043018587,  -0.049152344,  -0.12418144,   -0.078985475,  -0.07596889,
      0.019484362,  -0.11434962,   -0.0074034138, -0.06314844,   -0.092981495,
      0.0062155537, -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
      0.10665918,   -0.032036792,  -0.08505916,   -0.10843358,   -0.13002433,
      -0.036816437, -0.02130134,   -0.016518239,  0.0047691227,  -0.0025825808,
      0.066017866,  0.029991534,   -0.10652836,   -0.1037554,    -0.13056071,
      -0.03266643,  -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
      -0.025174323, 0.0396852,     0.081777506,   0.06157468,    0.10210095,
      -0.009658194, 0.046511717,   0.03603906,    0.0069369148,  0.015960095,
      -0.06507666,  0.09551598,    0.053568836,   0.06408714,    0.12835667,
      -0.008714329, -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
      -0.029227901, 0.1164364,     -0.08560263,   0.09941786,    -0.036999565,
      -0.028842626, -0.0033637602, -0.017012902,  -0.09720865,   -0.11193351,
      -0.029155117, -0.017936034,  -0.009768936,  -0.04223324,   -0.036159635,
      0.06505112,   -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
      0.05453865,   0.091149814,   0.06387331,    0.007518393,   0.055960953,
      0.069779344,  0.046411168,   0.10509911,    0.07463894,    0.0075130584,
      0.012850982,  0.04555431,    0.056955688,   0.06555285,    0.050801456,
      -0.009862683, 0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042};

  input_to_output_weights_ = {
      -0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
      -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
      0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
      -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
      -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
      0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
      -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
      -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
      -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
      -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
      0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
      0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
      0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
      -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
      0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
      0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
      -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
      0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
      -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
      -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956};

  input_gate_bias_ = {0.02234832,   0.14757581,  0.18176508,  0.10380666,
                      0.053110216,  -0.06928846, -0.13942584, -0.11816189,
                      0.19483899,   0.03652339,  -0.10250295, 0.036714908,
                      -0.18426876,  0.036065217, 0.21810818,  0.02383196,
                      -0.043370757, 0.08690144,  -0.04444982, 0.00030581196};

  forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
                       0.11098921,  0.15378423,   0.09263801,  0.09790885,
                       0.09508917,  0.061199076,  0.07665568,  -0.015443159,
                       -0.03499149, 0.046190713,  0.08895977,  0.10899629,
                       0.40694186,  0.06030037,   0.012413437, -0.06108739};

  cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
                     -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
                     -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
                     -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
                     0.016178843,  0.1749513,    0.13975595,   0.92058027};

  output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469,   0.12648113,
                       0.027195795, 0.35373217,    -0.018957434, 0.008907322,
                       -0.0762701,  0.12018895,    0.04216877,   0.0022856654,
                       0.040952638, 0.3147856,     0.08225149,   -0.057416286,
                       -0.14995944, -0.008040261,  0.13208859,   0.029760877};

  recurrent_to_input_weights_ = {
      -0.001374326,   -0.078856036,   0.10672688,    0.029162422,
      -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
      -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
      -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
      0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
      0.08981,        -0.045407712,   0.08682226,    -0.06867011,
      -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
      0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
      -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
      0.009352075,    0.22920375,     0.0016303885,  0.11583097,
      -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
      0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
      -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
      0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
      -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
      -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
      -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
      -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
      -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
      0.01068115,     0.032956902,    0.022433773,   0.0026891115,
      0.08944216,     -0.0685835,     0.010513544,   0.07228705,
      0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
      0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
      0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
      -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
      -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
      0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
      -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
      -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
      -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
      -0.017142897,   0.03312627,     0.009205989,   0.024138335,
      -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
      -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
      0.0365468,      0.07590991,     0.08838724,    0.021681072,
      -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
      0.023646897,    -0.095322326,   0.02233014,    0.09756986,
      -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
      -0.09801813,    0.019894179,    0.08502348,    0.004032281,
      0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
      -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
      -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
      0.010889619,    0.0047078193,   0.038385306,   0.08540671,
      -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
      0.015963363,    0.00871737,     0.060130805,   0.028611384,
      0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
      0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
      0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
      0.019899689,    0.006106124,    -0.027092824,  0.0786356,
      0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
      -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
      -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
      -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
      -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
      -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
      0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
      0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
      -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
      0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
      0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
      0.058618143,    -0.08598433,    0.00972939,    0.023867095,
      -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
      -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
      0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
      -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
      -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
      0.06358255,     0.18531723,     0.07759293,    0.12006465,
      0.1305557,      0.058638252,    -0.03393652,   0.09622831,
      -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
      -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
      0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
      0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
      0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
      0.08184801,     -0.019164372,   0.06791302,    0.034257166,
      -0.10307039,    0.021943003,    0.046745934,   0.0790918,
      -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
      -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
      -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
      0.026351685,    0.012641483,    0.07466548,    0.044301085,
      -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
      -0.04106223,    -0.028126027,   0.028473156,   0.10467447};

  recurrent_to_cell_weights_ = {
      -0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
      0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
      0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
      -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
      0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
      0.08089997,     0.05143358,    0.038261272,   0.03339287,
      -0.027673481,   0.044746667,   0.028349208,   0.020090483,
      -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
      -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
      -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
      0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
      -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
      -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
      0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
      0.010868644,    -0.031489216,  0.09525667,    0.013939797,
      0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
      -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
      0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
      0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
      -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
      0.02786344,     -0.014179351,  0.005264273,   0.14376344,
      0.015983658,    0.03406988,    -0.06939408,   0.040699873,
      0.02111075,     0.09669095,    0.041345075,   -0.08316494,
      -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
      0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
      -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
      0.06760663,     -0.027437469,  0.07216407,    0.06977076,
      -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
      0.043184172,    -0.037189785,  0.10420091,    0.00882477,
      -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
      0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
      0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
      -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
      0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
      -0.008264958,   0.042035464,   0.05891794,    0.029673764,
      0.0063542654,   0.044788733,   0.054816857,   0.062257513,
      -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
      -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
      -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
      -0.007376126,   0.003533447,   0.006570588,   0.056037236,
      0.12436656,     0.051817212,   0.028532185,   -0.08686856,
      0.11868599,     0.07663395,    -0.07323171,   0.03463402,
      -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
      0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
      0.023029093,    0.086124025,   0.006445803,   -0.03496501,
      0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
      -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
      0.09465633,     0.008115513,   -0.02171956,   0.08304309,
      0.071401566,    0.019622514,   0.032163795,   -0.004167056,
      0.02295182,     0.030739572,   0.056506045,   0.004612461,
      0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
      -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
      0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
      -0.0329582,     0.07922767,    0.029322514,   0.026405897,
      0.04207835,     -0.07073373,   0.063781224,   0.0859677,
      -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
      -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
      -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
      -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
      0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
      0.15978073,     0.10185836,    0.10298046,    -0.015476589,
      -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
      -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
      -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
      -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
      -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
      0.012962922,    -0.031234352,  0.07029052,    0.016418684,
      0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
      -0.054761406,   0.029065743,   0.052404847,   0.020238016,
      0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
      0.06262858,     0.009184685,   0.020785125,   -0.043904778,
      -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
      -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
      0.09232601,     -0.035886683,  0.06000002,    0.05229691,
      -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
      -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
      0.031502828,    0.036232427,   -0.031581745,  0.023051167,
      -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
      -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
      -0.008799762,   0.056595087,   0.0022273948,  0.055752404};

  recurrent_to_forget_weights_ = {
      -0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
      0.14811787,    0.10826372,    0.09471067,     0.03987225,
      -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
      0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
      0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
      -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
      -0.06193199,   0.055729095,   0.03736828,     0.020123724,
      0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
      -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
      -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
      0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
      -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
      -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
      -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
      0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
      0.013454138,   0.028934088,   0.01685226,     -0.086110644,
      -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
      0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
      0.03761666,    0.008096139,   -0.014454086,   0.014361001,
      -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
      -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
      0.060212336,   0.055259194,   0.06974018,     0.049454916,
      -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
      0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
      -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
      0.0042065294,  0.03881498,    0.019844765,    0.041858196,
      -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
      0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
      0.012321099,   0.082840554,   -0.029899208,   0.044217527,
      0.059855383,   0.07711018,    -0.045319796,   0.0948846,
      -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
      -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
      -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
      0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
      0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
      0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
      0.052958444,   0.07558703,    0.04817258,     0.044462286,
      -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
      0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
      0.024734668,   0.024614193,   -0.042046934,   0.09597743,
      -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
      -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
      -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
      0.04383914,    -0.046476185,  0.028658995,    0.060410924,
      0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
      0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
      0.015898481,   0.021362653,   -0.030262267,   0.016587038,
      -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
      -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
      0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
      -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
      -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
      -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
      -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
      0.15443139,    0.07684145,    0.036571592,    -0.035900835,
      -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
      -0.03858649,   0.01849943,    0.13872518,     0.01503974,
      0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
      -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
      0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
      0.05866852,    0.023947537,   -0.09445152,    0.035450947,
      0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
      0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
      0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
      0.051808182,   0.05875331,    -0.04536488,    0.001626336,
      -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
      0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
      -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
      -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
      0.11475477,    -0.023854522,  0.10071741,     0.0686208,
      -0.014250481,  0.034261297,   0.047418304,    0.08562733,
      -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
      0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
      0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
      0.014410365,   0.020995233,   0.17040324,     0.11511526,
      0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
      -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
      -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
      0.007076659,   0.10964551,    0.0409152,      0.008275321,
      -0.07283536,   0.07937492,    0.04192024,     -0.1075027};

  recurrent_to_output_weights_ = {
      0.025825322,   -0.05813119,  0.09495884,   -0.045984812,   -0.01255415,
      -0.0026479573, -0.08196161,  -0.054914974, -0.0046604523,  -0.029587349,
      -0.044576716,  -0.07480124,  -0.082868785, 0.023254942,    0.027502948,
      -0.0039728214, -0.08683098,  -0.08116779,  -0.014675607,   -0.037924774,
      -0.023314456,  -0.007401714, -0.09255757,  0.029460307,    -0.08829125,
      -0.005139627,  -0.08989442,  -0.0555066,   0.13596267,     -0.025062224,
      -0.048351806,  -0.03850004,  0.07266485,   -0.022414139,   0.05940088,
      0.075114764,   0.09597592,   -0.010211725, -0.0049794707,  -0.011523867,
      -0.025980417,  0.072999895,  0.11091378,   -0.081685916,   0.014416728,
      0.043229222,   0.034178585,  -0.07530371,  0.035837382,    -0.085607,
      -0.007721233,  -0.03287832,  -0.043848954, -0.06404588,    -0.06632928,
      -0.073643476,  0.008214239,  -0.045984086, 0.039764922,    0.03474462,
      0.060612556,   -0.080590084, 0.049127717,  0.04151091,     -0.030063879,
      0.008801774,   -0.023021035, -0.019558564, 0.05158114,     -0.010947698,
      -0.011825728,  0.0075720972, 0.0699727,    -0.0039981045,  0.069350146,
      0.08799282,    0.016156472,  0.035502106,  0.11695009,     0.006217345,
      0.13392477,    -0.037875112, 0.025745004,  0.08940699,     -0.00924166,
      0.0046702605,  -0.036598757, -0.08811812,  0.10522024,     -0.032441203,
      0.008176899,   -0.04454919,  0.07058152,   0.0067963637,   0.039206743,
      0.03259838,    0.03725492,   -0.09515802,  0.013326398,    -0.052055415,
      -0.025676316,  0.03198509,   -0.015951829, -0.058556724,   0.036879618,
      0.043357447,   0.028362012,  -0.05908629,  0.0059240665,   -0.04995891,
      -0.019187413,  0.0276265,    -0.01628143,  0.0025863599,   0.08800015,
      0.035250366,   -0.022165963, -0.07328642,  -0.009415526,   -0.07455109,
      0.11690406,    0.0363299,    0.07411125,   0.042103454,    -0.009660886,
      0.019076364,   0.018299393,  -0.046004917, 0.08891175,     0.0431396,
      -0.026327137,  -0.051502608, 0.08979574,   -0.051670972,   0.04940282,
      -0.07491107,   -0.021240504, 0.022596184,  -0.034280192,   0.060163025,
      -0.058211457,  -0.051837247, -0.01349775,  -0.04639988,    -0.035936575,
      -0.011681591,  0.064818054,  0.0073146066, -0.021745546,   -0.043124277,
      -0.06471268,   -0.07053354,  -0.029321948, -0.05330136,    0.016933719,
      -0.053782392,  0.13747959,   -0.1361751,   -0.11569455,    0.0033329215,
      0.05693899,    -0.053219706, 0.063698,     0.07977434,     -0.07924483,
      0.06936997,    0.0034815092, -0.007305279, -0.037325785,   -0.07251102,
      -0.033633437,  -0.08677009,  0.091591336,  -0.14165086,    0.021752775,
      0.019683983,   0.0011612234, -0.058154266, 0.049996935,    0.0288841,
      -0.0024567875, -0.14345716,  0.010955264,  -0.10234828,    0.1183656,
      -0.0010731248, -0.023590032, -0.072285876, -0.0724771,     -0.026382286,
      -0.0014920527, 0.042667855,  0.0018776858, 0.02986552,     0.009814309,
      0.0733756,     0.12289186,   0.018043943,  -0.0458958,     0.049412545,
      0.033632483,   0.05495232,   0.036686596,  -0.013781798,   -0.010036754,
      0.02576849,    -0.08307328,  0.010112348,  0.042521734,    -0.05869831,
      -0.071689695,  0.03876447,   -0.13275425,  -0.0352966,     -0.023077697,
      0.10285965,    0.084736146,  0.15568255,   -0.00040734606, 0.027835453,
      -0.10292561,   -0.032401145, 0.10053256,   -0.026142767,   -0.08271222,
      -0.0030240538, -0.016368777, 0.1070414,    0.042672627,    0.013456989,
      -0.0437609,    -0.022309763, 0.11576483,   0.04108048,     0.061026827,
      -0.0190714,    -0.0869359,   0.037901703,  0.0610107,      0.07202949,
      0.01675338,    0.086139716,  -0.08795751,  -0.014898893,   -0.023771819,
      -0.01965048,   0.007955471,  -0.043740474, 0.03346837,     -0.10549954,
      0.090567775,   0.042013682,  -0.03176985,  0.12569028,     -0.02421228,
      -0.029526481,  0.023851605,  0.031539805,  0.05292009,     -0.02344001,
      -0.07811758,   -0.08834428,  0.10094801,   0.16594367,     -0.06861939,
      -0.021256343,  -0.041093912, -0.06669611,  0.035498552,    0.021757556,
      -0.09302526,   -0.015403468, -0.06614931,  -0.051798206,   -0.013874718,
      0.03630673,    0.010412845,  -0.08077351,  0.046185967,    0.0035662893,
      0.03541868,    -0.094149634, -0.034814864, 0.003128424,    -0.020674974,
      -0.03944324,   -0.008110165, -0.11113267,  0.08484226,     0.043586485,
      0.040582247,   0.0968012,    -0.065249965, -0.028036479,   0.0050708856,
      0.0017462453,  0.0326779,    0.041296225,  0.09164146,     -0.047743853,
      -0.015952192,  -0.034451712, 0.084197424,  -0.05347844,    -0.11768019,
      0.085926116,   -0.08251791,  -0.045081906, 0.0948852,      0.068401024,
      0.024856757,   0.06978981,   -0.057309967, -0.012775832,   -0.0032452994,
      0.01977615,    -0.041040014, -0.024264973, 0.063464895,    0.05431621,
  };

  cell_to_input_weights_ = {
      0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
      -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
      -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
      0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175};

  cell_to_forget_weights_ = {
      -0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
      -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
      -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
      0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355};

  cell_to_output_weights_ = {0.08286371,  -0.08261836, -0.51210177, 0.002913762,
                             0.17764764,  -0.5495371,  -0.08460716, -0.24552552,
                             0.030037103, 0.04123544,  -0.11940523, 0.007358328,
                             0.1890978,   0.4833202,   -0.34441817, 0.36312827,
                             -0.26375428, 0.1457655,   -0.19724406, 0.15548733};

  projection_weights_ = {
      -0.009802181,  0.09401916,    0.0717386,     -0.13895074,  0.09641832,
      0.060420845,   0.08539281,    0.054285463,   0.061395317,  0.034448683,
      -0.042991187,  0.019801661,   -0.16840284,   -0.015726732, -0.23041931,
      -0.024478018,  -0.10959692,   -0.013875541,  0.18600968,   -0.061274476,
      0.0138165,     -0.08160894,   -0.07661644,   0.032372914,  0.16169067,
      0.22465782,    -0.03993472,   -0.004017731,  0.08633481,   -0.28869787,
      0.08682067,    0.17240396,    0.014975425,   0.056431185,  0.031037588,
      0.16702051,    0.0077946745,  0.15140012,    0.29405436,   0.120285,
      -0.188994,     -0.027265169,  0.043389652,   -0.022061434, 0.014777949,
      -0.20203483,   0.094781205,   0.19100232,    0.13987629,   -0.036132768,
      -0.06426278,   -0.05108664,   0.13221376,    0.009441198,  -0.16715929,
      0.15859416,    -0.040437475,  0.050779544,   -0.022187516, 0.012166504,
      0.027685808,   -0.07675938,   -0.0055694645, -0.09444123,  0.0046453946,
      0.050794356,   0.10770313,    -0.20790008,   -0.07149004,  -0.11425117,
      0.008225835,   -0.035802525,  0.14374903,    0.15262283,   0.048710253,
      0.1847461,     -0.007487823,  0.11000021,    -0.09542012,  0.22619456,
      -0.029149994,  0.08527916,    0.009043713,   0.0042746216, 0.016261552,
      0.022461696,   0.12689082,    -0.043589946,  -0.12035478,  -0.08361797,
      -0.050666027,  -0.1248618,    -0.1275799,    -0.071875185, 0.07377272,
      0.09944291,    -0.18897448,   -0.1593054,    -0.06526116,  -0.040107165,
      -0.004618631,  -0.067624845,  -0.007576253,  0.10727444,   0.041546922,
      -0.20424393,   0.06907816,    0.050412357,   0.00724631,   0.039827548,
      0.12449835,    0.10747581,    0.13708383,    0.09134148,   -0.12617786,
      -0.06428341,   0.09956831,    0.1208086,     -0.14676677,  -0.0727722,
      0.1126304,     0.010139365,   0.015571211,   -0.038128063, 0.022913318,
      -0.042050496,  0.16842307,    -0.060597885,  0.10531834,   -0.06411776,
      -0.07451711,   -0.03410368,   -0.13393489,   0.06534304,   0.003620307,
      0.04490757,    0.05970546,    0.05197996,    0.02839995,   0.10434969,
      -0.013699693,  -0.028353551,  -0.07260381,   0.047201227,  -0.024575593,
      -0.036445823,  0.07155557,    0.009672501,   -0.02328883,  0.009533515,
      -0.03606021,   -0.07421458,   -0.028082801,  -0.2678904,   -0.13221288,
      0.18419984,    -0.13012612,   -0.014588381,  -0.035059117, -0.04824723,
      0.07830115,    -0.056184657,  0.03277091,    0.025466874,  0.14494097,
      -0.12522776,   -0.098633975,  -0.10766018,   -0.08317623,  0.08594209,
      0.07749552,    0.039474737,   0.1776665,     -0.07409566,  -0.0477268,
      0.29323658,    0.10801441,    0.1154011,     0.013952499,  0.10739139,
      0.10708251,    -0.051456142,  0.0074137426,  -0.10430189,  0.10034707,
1058       0.045594677,   0.0635285,     -0.0715442,    -0.089667566, -0.10811871,
1059       0.00026344223, 0.08298446,    -0.009525053,  0.006585689,  -0.24567553,
1060       -0.09450807,   0.09648481,    0.026996298,   -0.06419476,  -0.04752702,
1061       -0.11063944,   -0.23441927,   -0.17608605,   -0.052156363, 0.067035615,
1062       0.19271925,    -0.0032889997, -0.043264326,  0.09663576,   -0.057112187,
1063       -0.10100678,   0.0628376,     0.04447668,    0.017961001,  -0.10094388,
1064       -0.10190601,   0.18335468,    0.10494553,    -0.052095775, -0.0026118709,
1065       0.10539724,    -0.04383912,   -0.042349473,  0.08438151,   -0.1947263,
1066       0.02251204,    0.11216432,    -0.10307853,   0.17351969,   -0.039091777,
1067       0.08066188,    -0.00561982,   0.12633002,    0.11335965,   -0.0088127935,
1068       -0.019777594,  0.06864014,    -0.059751723,  0.016233567,  -0.06894641,
1069       -0.28651384,   -0.004228674,  0.019708522,   -0.16305895,  -0.07468996,
1070       -0.0855457,    0.099339016,   -0.07580735,   -0.13775392,  0.08434318,
1071       0.08330512,    -0.12131499,   0.031935584,   0.09180414,   -0.08876437,
1072       -0.08049874,   0.008753825,   0.03498998,    0.030215185,  0.03907079,
1073       0.089751154,   0.029194152,   -0.03337423,   -0.019092513, 0.04331237,
1074       0.04299654,    -0.036394123,  -0.12915532,   0.09793732,   0.07512415,
1075       -0.11319543,   -0.032502122,  0.15661901,    0.07671967,   -0.005491124,
1076       -0.19379048,   -0.218606,     0.21448623,    0.017840758,  0.1416943,
1077       -0.07051762,   0.19488361,    0.02664691,    -0.18104725,  -0.09334311,
1078       0.15026465,    -0.15493552,   -0.057762887,  -0.11604192,  -0.262013,
1079       -0.01391798,   0.012185008,   0.11156489,    -0.07483202,  0.06693364,
1080       -0.26151478,   0.046425626,   0.036540434,   -0.16435726,  0.17338543,
1081       -0.21401681,   -0.11385144,   -0.08283257,   -0.069031075, 0.030635102,
1082       0.010969227,   0.11109743,    0.010919218,   0.027526086,  0.13519906,
1083       0.01891392,    -0.046839405,  -0.040167913,  0.017953383,  -0.09700955,
1084       0.0061885654,  -0.07000971,   0.026893595,   -0.038844477, 0.14543656};
1085 
1086   lstm_input_ = {// Step 1
1087                  {{0.787926, 0.151646, 0.071352, 0.118426, 0.458058},
1088                   {0.295743, 0.544053, 0.690064, 0.858138, 0.497181}},
1089                  // Step 2
1090                  {{0.596268, 0.998386, 0.568695, 0.864524, 0.571277},
1091                   {0.642421, 0.524260, 0.134799, 0.003639, 0.162482}},
1092                  // Step 3
1093                  {{0.073204, 0.296072, 0.743333, 0.069199, 0.045348},
1094                   {0.640394, 0.930399, 0.050782, 0.432485, 0.988078}},
1095                  // Step 4
1096                  {{0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
1097                   {0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}};
1098 
1099   lstm_golden_output_ = {
1100       {{-0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, -0.0211779,
1101         0.0283512, -0.0114597, 0.00907307, -0.0244004, -0.0152191, -0.0259063,
1102         0.00914318, 0.00415118, 0.017147, 0.0134203},
1103        {-0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, -0.0186926,
1104         0.0193662, -0.0115437, 0.00422612, -0.0345232, 0.00223253, -0.00957321,
1105         0.0210624, 0.013331, 0.0150954, 0.02168}},
1106 
1107       {{-0.0166936, 0.0381209, 0.000889694, 0.0143363, -0.0328911, -0.0234288,
1108         0.0333051, -0.012229, 0.0110322, -0.0457725, -0.000832209, -0.0202817,
1109         0.0327257, 0.0121308, 0.0155969, 0.0312091},
1110        {-0.0141913, 0.0322082, 0.00227024, 0.0260507, -0.0188721, -0.0296489,
1111         0.0399134, -0.0160509, 0.0116039, -0.0447318, -0.0150515, -0.0277406,
1112         0.0316596, 0.0118233, 0.0214762, 0.0293641}},
1113 
1114       {{-0.0213783, 0.0350169, 0.000324794, 0.0276012, -0.0263374, -0.0371449,
1115         0.0446149, -0.0205474, 0.0103729, -0.0576349, -0.0150052, -0.0292043,
1116         0.0376827, 0.0136115, 0.0243435, 0.0354492},
1117        {-0.0204549, 0.0450315, -0.00117378, 0.0167673, -0.0375007, -0.0238314,
1118         0.038784, -0.0174034, 0.0131743, -0.0506589, -0.0048447, -0.0240239,
1119         0.0325789, 0.00790065, 0.0220157, 0.0333314}},
1120 
1121       {{-0.0189322, 0.0464512, -0.00251373, 0.0225745, -0.0308346, -0.0317124,
1122         0.0460407, -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
1123         0.0286833, 0.00824207, 0.0264887, 0.0305169},
1124        {-0.0264787, 0.0387855, -0.000764675, 0.0217599, -0.037537, -0.0335206,
1125         0.0431679, -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
1126         0.0412031, 0.0118723, 0.0239643, 0.0394009}}};
1127 
1128   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
1129                    /*use_cifg=*/false, /*use_peephole=*/true,
1130                    /*use_projection_weights=*/true,
1131                    /*use_projection_bias=*/false, weight_type,
1132                    model_has_legacy_20_inputs, /*is_layer_norm=*/false,
1133                    asymmetric_quantize_inputs);
1134 
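  // Per-weight-type tolerances for the golden check: float weights reproduce
  // the goldens almost exactly, while quantized (uint8/int8) weights need
  // slack on the order of a weight-quantization step.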
1135   static const auto* tolerance_per_type = new std::map<TensorType, float>{
1136       {TensorType_FLOAT32, 0.00001f},
1137       {TensorType_UINT8, 0.00467f},
1138       {TensorType_INT8, 0.0015f},
1139   };
1140   VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
1141 }
1142 
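// VerifyGoldens (defined earlier in this file) drives the parameterized
// tests above and below. As an illustrative sketch only -- not the helper's
// actual code -- it is assumed to do, per sequence step:
//
//   // feed lstm_input_[step] into the model's input tensor, then:
//   ASSERT_EQ(lstm->Invoke(), kTfLiteOk);
//   EXPECT_THAT(lstm->GetOutput(),
//               ElementsAreArray(ArrayFloatNear(golden[step], tolerance)));
//
// which mirrors the explicit loops written out in the integer tests below.
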
1143 TEST_P(LstmOpTest, NoCifg_Peephole_Projection_LayerNorm) {
1144   const int n_batch = 2;
1145   const int n_input = 5;
1146   const int n_cell = 4;
1147   const int n_output = 3;
1148 
1149   TensorType weight_type;
1150   // Layer normalization needs the 24-input model, so the 20-input flag is ignored.
1151   bool asymmetric_quantize_inputs;
1152   std::tie(weight_type, std::ignore, asymmetric_quantize_inputs) = GetParam();
1153 
1154   // TODO(b/158205028): Fix this test if using NN-API.
1155   if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
1156     return;
1157   }
1158 
1159   input_to_input_weights_ = {0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,
1160                              0.3,  -0.4, 0.5,  -0.8, 0.7,  -0.6, 0.5,
1161                              -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1162 
1163   input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
1164                               -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
1165                               -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
1166 
1167   input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
1168                             -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
1169                             -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};
1170 
1171   input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
1172                               -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
1173                               -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};
1174 
1175   input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
1176 
1177   forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
1178 
1179   cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
1180 
1181   output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
1182 
1183   recurrent_to_input_weights_ = {-0.2, -0.3, 0.4,  0.1,  -0.5, 0.9,
1184                                  -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1185 
1186   recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
1187                                 -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1188 
1189   recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
1190                                   0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};
1191 
1192   recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
1193                                   -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1194 
1195   cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
1196 
1197   cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
1198 
1199   cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
1200 
1201   input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
1202   forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
1203   cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
1204   output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
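
  // The four coefficient vectors above are per-cell gains for layer
  // normalization of each gate's pre-activation. As a sketch of the usual
  // definition (not necessarily the kernel's exact code), for each cell c:
  //
  //   gate[c] = coeff[c] * (gate[c] - mean) / std::sqrt(variance + epsilon);
  //
  // with the gate bias added after normalization.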
1205 
1206   projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
1207                          0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};
1208 
1209   lstm_input_ = {
1210       {{0.7, 0.8, 0.1, 0.2, 0.3}, {0.3, 0.2, 0.9, 0.8, 0.1}},
1211 
1212       {{0.8, 0.1, 0.2, 0.4, 0.5}, {0.1, 0.5, 0.2, 0.4, 0.2}},
1213 
1214       {{0.2, 0.7, 0.7, 0.1, 0.7}, {0.6, 0.9, 0.2, 0.5, 0.7}},
1215   };
1216 
1217   lstm_golden_output_ = {
1218       {{0.0244077, 0.128027, -0.00170918}, {-0.00692428, 0.0848741, 0.063445}},
1219 
1220       {{0.0137642, 0.140751, 0.0395835}, {-0.00403912, 0.139963, 0.072681}},
1221 
1222       {{-0.00459231, 0.155278, 0.0837377}, {0.00752706, 0.161903, 0.0561371}}};
1223 
1224   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
1225                    /*use_cifg=*/false, /*use_peephole=*/true,
1226                    /*use_projection_weights=*/true,
1227                    /*use_projection_bias=*/false, weight_type,
1228                    /*model_has_legacy_20_inputs=*/false,
1229                    /*is_layer_norm=*/true, asymmetric_quantize_inputs);
1230 
1231   static const auto* tolerance_per_type =
1232       new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
1233                                       {TensorType_UINT8, 0.0010907f},
1234                                       {TensorType_INT8, 0.00106f}};
1235   VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
1236 }
1237 
1238 TEST_P(LstmOpTest, Cifg_Peephole_Projection_LayerNorm) {
1239   const int n_batch = 2;
1240   const int n_input = 5;
1241   const int n_cell = 4;
1242   const int n_output = 3;
1243 
1244   TensorType weight_type;
1245   // Layer normalization needs the 24-input model, so the 20-input flag is ignored.
1246   bool asymmetric_quantize_inputs;
1247   std::tie(weight_type, std::ignore, asymmetric_quantize_inputs) = GetParam();
1248 
1249   // TODO(b/158205028): Fix this test if using NN-API.
1250   if (SingleOpModel::GetForceUseNnapi() && weight_type == TensorType_UINT8) {
1251     return;
1252   }
1253 
1254   input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
1255                               -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
1256                               -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
1257   input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
1258                             -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
1259                             -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};
1260   input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
1261                               -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
1262                               -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};
1263 
1264   forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
1265   cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
1266   output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
1267 
1268   recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
1269                                 -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1270   recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
1271                                   0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};
1272   recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
1273                                   -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1274 
1275   cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
1276   cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
1277 
1278   forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
1279   cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
1280   output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
1281   projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
1282                          0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};
1283 
1284   lstm_input_ = {{{0.7, 0.8, 0.1, 0.2, 0.3}, {0.3, 0.2, 0.9, 0.8, 0.1}},
1285 
1286                  {{0.8, 0.1, 0.2, 0.4, 0.5}, {0.1, 0.5, 0.2, 0.4, 0.2}},
1287 
1288                  {{0.2, 0.7, 0.7, 0.1, 0.7}, {0.6, 0.9, 0.2, 0.5, 0.7}}};
1289   lstm_golden_output_ = {{{0.02129706, 0.140816242, 0.0112733059},
1290                           {-0.0226350538, 0.0916948169, 0.0769175813}},
1291 
1292                          {{0.0132302344, 0.152308047, 0.0346313119},
1293                           {-0.0269966982, 0.149707705, 0.094149217}},
1294 
1295                          {{-0.0123688057, 0.165790111, 0.0893077999},
1296                           {-0.0103429332, 0.173016444, 0.0720508844}}};
1297 
1298   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
1299                    /*use_cifg=*/true, /*use_peephole=*/true,
1300                    /*use_projection_weights=*/true,
1301                    /*use_projection_bias=*/false, weight_type,
1302                    /*model_has_legacy_20_inputs=*/false,
1303                    /*is_layer_norm=*/true, asymmetric_quantize_inputs);
1304 
1305   static const auto* tolerance_per_type =
1306       new std::map<TensorType, float>{{TensorType_FLOAT32, 0.00001f},
1307                                       {TensorType_UINT8, 0.000971057f},
1308                                       {TensorType_INT8, 0.001f}};
1309   VerifyGoldens(&lstm, tolerance_per_type->at(weight_type));
1310 }
1311 
1312 class LSTMIntegerOpModel : public SingleOpModel {
1313  public:
1314   LSTMIntegerOpModel(int n_batch, int n_input, int n_cell, int n_output,
1315                      bool use_cifg, bool use_peephole,
1316                      bool use_projection_weights, bool use_projection_bias,
1317                      bool use_layer_norm, bool use_8x8_8_implementation,
1318                      const std::vector<std::pair<float, float>>& ranges,
1319                      const std::vector<std::pair<float, int>>& intermediates)
1320       : n_input_(n_input), n_output_(n_output) {
1321     input_ = AddInput({TensorType_INT8,
1322                        {n_batch, n_input},
1323                        ranges[0].first,
1324                        ranges[0].second});
1325 
1326     if (use_cifg) {
1327       input_to_input_weights_ = AddNullInput();
1328     } else {
1329       input_to_input_weights_ = AddInput({TensorType_INT8,
1330                                           {n_cell, n_input},
1331                                           ranges[1].first,
1332                                           ranges[1].second});
1333     }
1334     input_to_forget_weights_ = AddInput({TensorType_INT8,
1335                                          {n_cell, n_input},
1336                                          ranges[2].first,
1337                                          ranges[2].second});
1338     input_to_cell_weights_ = AddInput({TensorType_INT8,
1339                                        {n_cell, n_input},
1340                                        ranges[3].first,
1341                                        ranges[3].second});
1342     input_to_output_weights_ = AddInput({TensorType_INT8,
1343                                          {n_cell, n_input},
1344                                          ranges[4].first,
1345                                          ranges[4].second});
1346 
1347     if (use_cifg) {
1348       recurrent_to_input_weights_ = AddNullInput();
1349     } else {
1350       recurrent_to_input_weights_ = AddInput({TensorType_INT8,
1351                                               {n_cell, n_output},
1352                                               ranges[5].first,
1353                                               ranges[5].second});
1354     }
1355     recurrent_to_forget_weights_ = AddInput({TensorType_INT8,
1356                                              {n_cell, n_output},
1357                                              ranges[6].first,
1358                                              ranges[6].second});
1359     recurrent_to_cell_weights_ = AddInput({TensorType_INT8,
1360                                            {n_cell, n_output},
1361                                            ranges[7].first,
1362                                            ranges[7].second});
1363     recurrent_to_output_weights_ = AddInput({TensorType_INT8,
1364                                              {n_cell, n_output},
1365                                              ranges[8].first,
1366                                              ranges[8].second});
1367 
1368     if (use_peephole) {
1369       if (use_cifg) {
1370         cell_to_input_weights_ = AddNullInput();
1371       } else {
1372         cell_to_input_weights_ = AddInput(
1373             {TensorType_INT16, {n_cell}, ranges[9].first, ranges[9].second});
1374       }
1375       cell_to_forget_weights_ = AddInput(
1376           {TensorType_INT16, {n_cell}, ranges[10].first, ranges[10].second});
1377       cell_to_output_weights_ = AddInput(
1378           {TensorType_INT16, {n_cell}, ranges[11].first, ranges[11].second});
1379     } else {
1380       cell_to_input_weights_ = AddNullInput();
1381       cell_to_forget_weights_ = AddNullInput();
1382       cell_to_output_weights_ = AddNullInput();
1383     }
1384 
1385     if (use_cifg) {
1386       input_gate_bias_ = AddNullInput();
1387     } else {
1388       input_gate_bias_ = AddInput(
1389           {TensorType_INT32, {n_cell}, ranges[12].first, ranges[12].second});
1390     }
1391     forget_gate_bias_ = AddInput(
1392         {TensorType_INT32, {n_cell}, ranges[13].first, ranges[13].second});
1393     cell_gate_bias_ = AddInput(
1394         {TensorType_INT32, {n_cell}, ranges[14].first, ranges[14].second});
1395     output_gate_bias_ = AddInput(
1396         {TensorType_INT32, {n_cell}, ranges[15].first, ranges[15].second});
1397 
1398     if (use_projection_weights) {
1399       projection_weights_ = AddInput({TensorType_INT8,
1400                                       {n_output, n_cell},
1401                                       ranges[16].first,
1402                                       ranges[16].second});
1403     } else {
1404       projection_weights_ = AddNullInput();
1405     }
1406     if (use_projection_bias) {
1407       CHECK(use_projection_weights);
1408       projection_bias_ = AddInput(
1409           {TensorType_INT32, {n_output}, ranges[17].first, ranges[17].second});
1410     } else {
1411       projection_bias_ = AddNullInput();
1412     }
1413 
1414     // Add the two state tensors.
1415     AddVariableInput({TensorType_INT16,
1416                       {n_batch, n_output},
1417                       ranges[18].first,
1418                       ranges[18].second});
1419     AddVariableInput({TensorType_INT16,
1420                       {n_batch, n_cell},
1421                       ranges[19].first,
1422                       ranges[19].second});
1423 
1424     // Layer norm weights.
1425     if (use_layer_norm) {
1426       if (use_cifg) {
1427         input_layer_norm_coefficients_ = AddNullInput();
1428       } else {
1429         input_layer_norm_coefficients_ = AddInput(
1430             {TensorType_INT16, {n_cell}, ranges[20].first, ranges[20].second});
1431       }
1432       forget_layer_norm_coefficients_ = AddInput(
1433           {TensorType_INT16, {n_cell}, ranges[21].first, ranges[21].second});
1434       cell_layer_norm_coefficients_ = AddInput(
1435           {TensorType_INT16, {n_cell}, ranges[22].first, ranges[22].second});
1436       output_layer_norm_coefficients_ = AddInput(
1437           {TensorType_INT16, {n_cell}, ranges[23].first, ranges[23].second});
1438     }
1439 
1440     if (use_8x8_8_implementation) {
1441       EXPECT_EQ(intermediates.size(), 12);
1442     } else {
1443       EXPECT_EQ(intermediates.size(), 5);
1444     }
1445     for (int i = 0; i < intermediates.size(); ++i) {
1446       AddIntermediate(TensorType_INT16, {intermediates[i].first},
1447                       {intermediates[i].second});
1448     }
1449 
1450     output_ = AddOutput({TensorType_INT8,
1451                          {n_batch, n_output},
1452                          ranges[24].first,
1453                          ranges[24].second});
1454 
1455     // TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the
1456     // default 0.
1457     SetBuiltinOp(
1458         BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
1459         CreateLSTMOptions(builder_, ActivationFunctionType_TANH).Union());
1460 
1461     BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1,
1462                      /*allow_fp32_relax_to_fp16=*/false,
1463                      /*apply_delegate=*/true, /*allocate_and_delegate=*/false);
1464   }
1465 
1466   void PerformAllocateAndDelegate() { AllocateAndDelegate(true); }
1467 
1468   void SetInputToInputWeights(const std::vector<float>& f) {
1469     QuantizeAndPopulate<int8_t>(input_to_input_weights_, f);
1470   }
1471 
1472   void SetInputToForgetWeights(const std::vector<float>& f) {
1473     QuantizeAndPopulate<int8_t>(input_to_forget_weights_, f);
1474   }
1475 
1476   void SetInputToCellWeights(const std::vector<float>& f) {
1477     QuantizeAndPopulate<int8_t>(input_to_cell_weights_, f);
1478   }
1479 
1480   void SetInputToOutputWeights(const std::vector<float>& f) {
1481     QuantizeAndPopulate<int8_t>(input_to_output_weights_, f);
1482   }
1483 
1484   void SetRecurrentToInputWeights(const std::vector<float>& f) {
1485     QuantizeAndPopulate<int8_t>(recurrent_to_input_weights_, f);
1486   }
1487 
1488   void SetRecurrentToForgetWeights(const std::vector<float>& f) {
1489     QuantizeAndPopulate<int8_t>(recurrent_to_forget_weights_, f);
1490   }
1491 
1492   void SetRecurrentToCellWeights(const std::vector<float>& f) {
1493     QuantizeAndPopulate<int8_t>(recurrent_to_cell_weights_, f);
1494   }
1495 
1496   void SetRecurrentToOutputWeights(const std::vector<float>& f) {
1497     QuantizeAndPopulate<int8_t>(recurrent_to_output_weights_, f);
1498   }
1499 
1500   void SetCellToInputWeights(const std::vector<float>& f) {
1501     QuantizeAndPopulate<int16_t>(cell_to_input_weights_, f);
1502   }
1503 
1504   void SetCellToForgetWeights(const std::vector<float>& f) {
1505     QuantizeAndPopulate<int16_t>(cell_to_forget_weights_, f);
1506   }
1507 
1508   void SetCellToOutputWeights(const std::vector<float>& f) {
1509     QuantizeAndPopulate<int16_t>(cell_to_output_weights_, f);
1510   }
1511 
1512   void SetInputLayerNormCoefficients(const std::vector<float>& f) {
1513     QuantizeAndPopulate<int16_t>(input_layer_norm_coefficients_, f);
1514   }
1515 
1516   void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
1517     QuantizeAndPopulate<int16_t>(forget_layer_norm_coefficients_, f);
1518   }
1519 
1520   void SetCellLayerNormCoefficients(const std::vector<float>& f) {
1521     QuantizeAndPopulate<int16_t>(cell_layer_norm_coefficients_, f);
1522   }
1523 
1524   void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
1525     QuantizeAndPopulate<int16_t>(output_layer_norm_coefficients_, f);
1526   }
1527 
1528   void SetInputGateBias(const std::vector<float>& f) {
1529     QuantizeAndPopulate<int32_t>(input_gate_bias_, f);
1530   }
1531 
1532   void SetForgetGateBias(const std::vector<float>& f) {
1533     QuantizeAndPopulate<int32_t>(forget_gate_bias_, f);
1534   }
1535 
1536   void SetCellBias(const std::vector<float>& f) {
1537     QuantizeAndPopulate<int32_t>(cell_gate_bias_, f);
1538   }
1539 
1540   void SetOutputGateBias(const std::vector<float>& f) {
1541     QuantizeAndPopulate<int32_t>(output_gate_bias_, f);
1542   }
1543 
1544   void SetProjectionWeights(const std::vector<float>& f) {
1545     QuantizeAndPopulate<int8_t>(projection_weights_, f);
1546   }
1547 
1548   void SetProjectionBias(const std::vector<float>& f) {
1549     QuantizeAndPopulate<int32_t>(projection_bias_, f);
1550   }
1551 
1552   void SetInput(const std::vector<float>& f) {
1553     QuantizeAndPopulate<int8_t>(input_, f);
1554   }
1555 
1556   std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
1557 
1558   int num_inputs() { return n_input_; }
1559   int num_outputs() { return n_output_; }
1560 
1561  protected:
1562   int input_;
1563   int input_to_input_weights_;
1564   int input_to_forget_weights_;
1565   int input_to_cell_weights_;
1566   int input_to_output_weights_;
1567 
1568   int recurrent_to_input_weights_;
1569   int recurrent_to_forget_weights_;
1570   int recurrent_to_cell_weights_;
1571   int recurrent_to_output_weights_;
1572 
1573   int cell_to_input_weights_;
1574   int cell_to_forget_weights_;
1575   int cell_to_output_weights_;
1576 
1577   int input_layer_norm_coefficients_;
1578   int forget_layer_norm_coefficients_;
1579   int cell_layer_norm_coefficients_;
1580   int output_layer_norm_coefficients_;
1581 
1582   int input_gate_bias_;
1583   int forget_gate_bias_;
1584   int cell_gate_bias_;
1585   int output_gate_bias_;
1586 
1587   int projection_weights_;
1588   int projection_bias_;
1589 
1590   int output_;
1591 
1592   int n_input_;
1593   int n_output_;
1594 };
1595 
1596 TEST(IntegerLstmOpTest, NoCifg_NoPeephole_Projection_LayerNorm) {
1597   // Hyperparameters.
1598   const int n_batch = 2;
1599   const int n_input = 5;
1600   const int n_cell = 4;
1601   const int n_output = 3;
1602 
1603   // Model-related weights.
1604   const std::vector<float> input_to_input_weights = {
1605       0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
1606       -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1607 
1608   const std::vector<float> input_to_forget_weights = {
1609       -0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
1610       -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
1611 
1612   const std::vector<float> input_to_cell_weights = {
1613       -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
1614       0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6};
1615 
1616   const std::vector<float> input_to_output_weights = {
1617       -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
1618       0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4};
1619 
1620   const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
1621 
1622   const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
1623 
1624   const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
1625 
1626   const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
1627 
1628   const std::vector<float> recurrent_to_input_weights = {
1629       -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1630 
1631   const std::vector<float> recurrent_to_cell_weights = {
1632       -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1633 
1634   const std::vector<float> recurrent_to_forget_weights = {
1635       -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1636 
1637   const std::vector<float> recurrent_to_output_weights = {
1638       0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1639 
1640   const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
1641   const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
1642                                                              0.3};
1643   const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
1644   const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
1645                                                              0.5};
1646 
1647   const std::vector<float> projection_weights = {
1648       -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1649 
1650   // Input ranges.
1651   const std::vector<std::pair<float, float>> ranges = {
1652       {-1.0, 127.0 / 128},  // input tensor
1653       {-1.0, 1.0},          // input_to_input_weight tensor
1654       {-1.0, 1.0},          // input_to_forget_weight tensor
1655       {-1.0, 1.0},          // input_to_cell_weight tensor
1656       {-1.0, 1.0},          // input_to_output_weight tensor
1657 
1658       {-1.0, 1.0},  // recurrent_to_input_weight tensor
1659       {-1.0, 1.0},  // recurrent_to_forget_weight tensor
1660       {-1.0, 1.0},  // recurrent_to_cell_weight tensor
1661       {-1.0, 1.0},  // recurrent_to_output_weight tensor
1662 
1663       {-1, 1},  // cell_to_input_weight tensor
1664       {-1, 1},  // cell_to_forget_weight tensor
1665       {-1, 1},  // cell_to_output_weight tensor
1666 
1667       {-100, 100},  // input_gate_bias tensor
1668       {-100, 100},  // forget_gate_bias tensor
1669       {-100, 100},  // cell_gate_bias tensor
1670       {-100, 100},  // output_gate_bias tensor
1671 
1672       {-0.5, 0.5},  // projection_weight tensor
1673       {-1, 1},      // projection_bias tensor
1674 
1675       {-1.0, 32767.0 / 32768},  // output_state tensor
1676       {-1, 1},                  // cell_state tensor
1677 
1678       {-1.00001, 1.0},  // input_layer_norm_coefficient tensor
1679       {-1.00001, 1.0},  // forget_layer_norm_coefficient tensor
1680       {-1.00001, 1.0},  // cell_layer_norm_coefficient tensor
1681       {-1.00001, 1.0},  // output_layer_norm_coefficient tensor
1682       // Output scale is the same as output_state scale and only output_state
1683       // scale is used in the op, so this is only provided for clarity.
1684       {-1.0, 32767.0 / 32768},  // output tensor.
1685   };
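
  // For reference: SingleOpModel derives a (scale, zero_point) pair from
  // each (min, max) range above. A minimal sketch, assuming the standard
  // asymmetric formula for int8:
  //
  //   const float scale = (max - min) / (127.0f - (-128.0f));
  //   const int32_t zero_point =
  //       static_cast<int32_t>(std::round(-128.0f - min / scale));
  //
  // e.g. {-1.0, 127.0 / 128} yields scale = 1/128 and zero_point = 0, which
  // is why the ranges are deliberately slightly asymmetric.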
1686 
1687   // The scale and zero point of intermediate tensors.
1688   std::vector<std::pair<float, int>> intermediates = {
1689       {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
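
  // Five intermediates match the default 8x8_16 kernel: one per gate
  // pre-activation (consumed by layer normalization) plus one for the hidden
  // state feeding the projection. (This mapping is inferred from the op
  // implementation; the model class only checks the count.)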
1690 
1691   // Create model.
1692   LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
1693                           /*use_cifg=*/false, /*use_peephole=*/false,
1694                           /*use_projection_weights=*/true,
1695                           /*use_projection_bias=*/false,
1696                           /*use_layer_norm=*/true,
1697                           /*use_8x8_8_implementation=*/false, ranges,
1698                           intermediates);
1699   // Allocate tensors and apply the delegate.
1700   lstm.PerformAllocateAndDelegate();
1701 
1702   // Set weights.
1703   lstm.SetInputToInputWeights(input_to_input_weights);
1704   lstm.SetInputToCellWeights(input_to_cell_weights);
1705   lstm.SetInputToForgetWeights(input_to_forget_weights);
1706   lstm.SetInputToOutputWeights(input_to_output_weights);
1707 
1708   lstm.SetInputGateBias(input_gate_bias);
1709   lstm.SetCellBias(cell_gate_bias);
1710   lstm.SetForgetGateBias(forget_gate_bias);
1711   lstm.SetOutputGateBias(output_gate_bias);
1712 
1713   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
1714   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
1715   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
1716   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
1717 
1718   lstm.SetProjectionWeights(projection_weights);
1719 
1720   lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
1721   lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
1722   lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
1723   lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
1724 
1725   // Model inputs: one flat {batch x input} vector per sequence step.
1726   const std::vector<std::vector<float>> lstm_input = {
1727       {
1728           0.7, 0.8, 0.1, 0.2, 0.3,  //
1729           0.8, 0.1, 0.2, 0.4, 0.5,  //
1730       },
1731       {
1732           0.2, 0.7, 0.7, 0.1, 0.7,  //
1733           0.3, 0.2, 0.9, 0.8, 0.1,  //
1734       },
1735       {
1736           0.7, 0.8, 0.1, 0.2, 0.3,  //
1737           0.3, 0.2, 0.9, 0.8, 0.1,  //
1738       },
1739   };
1740 
1741   // Expected outputs.
1742   const std::vector<std::vector<int8_t>> expected_output = {
1743       {127, 127, -108, -67, 127, 127},
1744       {-128, 127, 127, -128, 127, 127},
1745       {127, 127, 127, -128, 127, 127},
1746   };
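
  // Note: many expected values sit at the int8 limits (-128/127). With the
  // output range {-1.0, 32767.0 / 32768}, float outputs at or beyond the
  // bounds quantize to the saturation points.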
1747 
1748   // Invoke and verify the result.
1749   const int input_sequence_size = lstm_input.size();
1750   EXPECT_GT(input_sequence_size, 0);
1751   for (int i = 0; i < input_sequence_size; ++i) {
1752     lstm.SetInput(lstm_input[i]);
1753     ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1754     EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
1755   }
1756 }
1757 
1758 TEST(IntegerLstmOpTest, NoCifg_Peephole_Projection_LayerNorm) {
1759   // Hyperparameters.
1760   const int n_batch = 2;
1761   const int n_input = 5;
1762   const int n_cell = 4;
1763   const int n_output = 3;
1764 
1765   // Model-related weights.
1766   const std::vector<float> input_to_input_weights = {
1767       0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
1768       -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1769 
1770   const std::vector<float> input_to_forget_weights = {
1771       -0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
1772       -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
1773 
1774   const std::vector<float> input_to_cell_weights = {
1775       -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
1776       0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6};
1777 
1778   const std::vector<float> input_to_output_weights = {
1779       -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
1780       0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4};
1781 
1782   const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
1783 
1784   const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
1785 
1786   const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
1787 
1788   const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
1789 
1790   const std::vector<float> recurrent_to_input_weights = {
1791       -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1792 
1793   const std::vector<float> recurrent_to_cell_weights = {
1794       -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1795 
1796   const std::vector<float> recurrent_to_forget_weights = {
1797       -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1798 
1799   const std::vector<float> recurrent_to_output_weights = {
1800       0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1801 
1802   const std::vector<float> cell_to_input_weights = {0.3, -0.1, 0.1, -0.2};
1803 
1804   const std::vector<float> cell_to_forget_weights = {0.2, -0.1, 0.1, -0.2};
1805 
1806   const std::vector<float> cell_to_output_weights = {0.3, -0.1, 0.1, -0.3};
1807 
1808   const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
1809   const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
1810                                                              0.3};
1811   const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
1812   const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
1813                                                              0.5};
1814 
1815   const std::vector<float> projection_weights = {
1816       -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1817 
1818   // Input ranges.
1819   const std::vector<std::pair<float, float>> ranges = {
1820       {-1.0, 127.0 / 128},  // input tensor
1821       {-1.0, 1.0},          // input_to_input_weight tensor
1822       {-1.0, 1.0},          // input_to_forget_weight tensor
1823       {-1.0, 1.0},          // input_to_cell_weight tensor
1824       {-1.0, 1.0},          // input_to_output_weight tensor
1825 
1826       {-1.0, 1.0},  // recurrent_to_input_weight tensor
1827       {-0.9, 0.9},  // recurrent_to_forget_weight tensor
1828       {-1.0, 1.0},  // recurrent_to_cell_weight tensor
1829       {-1.0, 1.0},  // recurrent_to_output_weight tensor
1830 
1831       {-0.3, 0.3},  // cell_to_input_weight tensor
1832       {-0.3, 0.3},  // cell_to_forget_weight tensor
1833       {-0.3, 0.3},  // cell_to_output_weight tensor
1834 
1835       {-100, 100},  // input_gate_bias tensor
1836       {-100, 80},   // forget_gate_bias tensor
1837       {-100, 100},  // cell_gate_bias tensor
1838       {-100, 100},  // output_gate_bias tensor
1839 
1840       {-0.5, 0.5},  // projection_weight tensor
1841       {-1, 1},      // projection_bias tensor
1842 
1843       {-1.0, 32767.0 / 32768},  // output_state tensor
1844       {-1, 1},                  // cell_state tensor
1845 
1846       {-0.5, 0.5},  // input_layer_norm_coefficient tensor
1847       {-0.5, 0.5},  // forget_layer_norm_coefficient tensor
1848       {-1.0, 1.0},  // cell_layer_norm_coefficient tensor
1849       {-1.0, 1.0},  // output_layer_norm_coefficient tensor
1850       // Output scale is the same as output_state scale and only output_state
1851       // scale is used in the op, so this is only provided for clarity.
1852       {-1.0, 32767.0 / 32768},  // output tensor.
1853   };
1854 
1855   // The scale and zero point of intermediate tensors.
1856   std::vector<std::pair<float, int>> intermediates = {
1857       {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
1858 
1859   // Create model.
1860   LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
1861                           /*use_cifg=*/false, /*use_peephole=*/true,
1862                           /*use_projection_weights=*/true,
1863                           /*use_projection_bias=*/false,
1864                           /*use_layer_norm=*/true,
1865                           /*use_8x8_8_implementation=*/false, ranges,
1866                           intermediates);
1867 
1868   // Allocate tensors and apply the delegate.
1869   lstm.PerformAllocateAndDelegate();
1870 
1871   // Set weights.
1872   lstm.SetInputToInputWeights(input_to_input_weights);
1873   lstm.SetInputToCellWeights(input_to_cell_weights);
1874   lstm.SetInputToForgetWeights(input_to_forget_weights);
1875   lstm.SetInputToOutputWeights(input_to_output_weights);
1876 
1877   lstm.SetInputGateBias(input_gate_bias);
1878   lstm.SetCellBias(cell_gate_bias);
1879   lstm.SetForgetGateBias(forget_gate_bias);
1880   lstm.SetOutputGateBias(output_gate_bias);
1881 
1882   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
1883   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
1884   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
1885   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
1886 
1887   lstm.SetCellToInputWeights(cell_to_input_weights);
1888   lstm.SetCellToForgetWeights(cell_to_forget_weights);
1889   lstm.SetCellToOutputWeights(cell_to_output_weights);
1890 
1891   lstm.SetProjectionWeights(projection_weights);
1892 
1893   lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
1894   lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
1895   lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
1896   lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
1897 
1898   // Model inputs: one flat {batch x input} vector per sequence step.
1899   const std::vector<std::vector<float>> lstm_input = {
1900       {
1901           0.7, 0.8, 0.1, 0.2, 0.3,  //
1902           0.8, 0.1, 0.2, 0.4, 0.5,  //
1903       },
1904       {
1905           0.2, 0.7, 0.7, 0.1, 0.7,  //
1906           0.3, 0.2, 0.9, 0.8, 0.1,  //
1907       },
1908       {
1909           0.7, 0.8, 0.1, 0.2, 0.3,  //
1910           0.3, 0.2, 0.9, 0.8, 0.1,  //
1911       },
1912   };
1913 
1914   // Expected outputs.
1915   const std::vector<std::vector<int8_t>> expected_output = {
1916       {127, 127, -16, -21, 127, 127},
1917       {23, 127, 127, -128, 127, 127},
1918       {127, 127, 127, -128, 127, 127},
1919   };
1920 
1921   // Invoke and verify the result.
1922   const int input_sequence_size = lstm_input.size();
1923   EXPECT_GT(input_sequence_size, 0);
1924   for (int i = 0; i < input_sequence_size; ++i) {
1925     lstm.SetInput(lstm_input[i]);
1926     ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1927     EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
1928   }
1929 }
1930 
1931 TEST(IntegerLstmOpTest, Cifg_NoPeephole_Projection_LayerNorm_8x8_8) {
1932   // Hyperparameters.
1933   const int n_batch = 2;
1934   const int n_input = 5;
1935   const int n_cell = 4;
1936   const int n_output = 3;
1937 
1938   // Model-related weights.
1939   const std::vector<float> input_to_input_weights = {
1940       0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
1941       -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
1942 
1943   const std::vector<float> input_to_forget_weights = {
1944       -0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
1945       -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
1946 
1947   const std::vector<float> input_to_cell_weights = {
1948       -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
1949       0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6};
1950 
1951   const std::vector<float> input_to_output_weights = {
1952       -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
1953       0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4};
1954 
1955   const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
1956 
1957   const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
1958 
1959   const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
1960 
1961   const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
1962 
1963   const std::vector<float> recurrent_to_input_weights = {
1964       -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
1965 
1966   const std::vector<float> recurrent_to_cell_weights = {
1967       -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
1968 
1969   const std::vector<float> recurrent_to_forget_weights = {
1970       -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
1971 
1972   const std::vector<float> recurrent_to_output_weights = {
1973       0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
1974 
1975   const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
1976   const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
1977                                                              0.3};
1978   const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
1979   const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
1980                                                              0.5};
1981 
1982   const std::vector<float> projection_weights = {
1983       -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
1984   const std::vector<float> projection_bias = {0.1, 0.3, 0.5};
1985 
1986   // Input ranges.
1987   const std::vector<std::pair<float, float>> ranges = {
1988       {-1.0, 127.0 / 128},  // input tensor
1989       {-1.0, 1.0},          // input_to_input_weight tensor
1990       {-1.0, 1.0},          // input_to_forget_weight tensor
1991       {-1.0, 1.0},          // input_to_cell_weight tensor
1992       {-1.0, 1.0},          // input_to_output_weight tensor
1993 
1994       {-1.0, 1.0},  // recurrent_to_input_weight tensor
1995       {-1.0, 1.0},  // recurrent_to_forget_weight tensor
1996       {-1.0, 1.0},  // recurrent_to_cell_weight tensor
1997       {-1.0, 1.0},  // recurrent_to_output_weight tensor
1998 
1999       {-1, 1},  // cell_to_input_weight tensor
2000       {-1, 1},  // cell_to_forget_weight tensor
2001       {-1, 1},  // cell_to_output_weight tensor
2002 
2003       {-100, 100},  // input_gate_bias tensor
2004       {-100, 100},  // forget_gate_bias tensor
2005       {-100, 100},  // cell_gate_bias tensor
2006       {-100, 100},  // output_gate_bias tensor
2007 
2008       {-0.5, 0.5},  // projection_weight tensor
2009       {-1, 1},      // projection_bias tensor
2010 
2011       {-1.0, 32767.0 / 32768},  // output_state tensor
2012       {-1.0, 32767.0 / 32768},  // cell_state tensor
2013 
2014       {-1.00001, 1.0},  // input_layer_norm_coefficient tensor
2015       {-1.00001, 1.0},  // forget_layer_norm_coefficient tensor
2016       {-1.00001, 1.0},  // cell_layer_norm_coefficient tensor
2017       {-1.00001, 1.0},  // output_layer_norm_coefficient tensor
2018       // Output scale is the same as output_state scale and only output_state
2019       // scale is used in the op, so this is only provided for clarity.
2020       {-1.0, 32767.0 / 32768},  // output tensor.
2021   };
2022 
2023   // The scale and zero point of intermediate tensors.
2024   std::vector<std::pair<float, int>> intermediates = {
2025       {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0},
2026       {0.007, 0},    {0.007059, 0}, {0.007, 0},    {0.007, 0},
2027       {0.007059, 0}, {0.007, 0},    {0.007, 0},    {0.3, 0}};
2028 
2029   // Create model.
2030   LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output,
2031                           /*use_cifg=*/true, /*use_peephole=*/false,
2032                           /*use_projection_weights=*/true,
2033                           /*use_projection_bias=*/true,
2034                           /*use_layer_norm=*/true,
2035                           /*use_8x8_8_implementation=*/true, ranges,
2036                           intermediates);
2037 
2038   // Allocate tensors and apply the delegate.
2039   lstm.PerformAllocateAndDelegate();
2040 
2041   // Set weights (input-gate setters skipped: CIFG leaves those tensors null).
2042   // lstm.SetInputToInputWeights(input_to_input_weights);
2043   lstm.SetInputToCellWeights(input_to_cell_weights);
2044   lstm.SetInputToForgetWeights(input_to_forget_weights);
2045   lstm.SetInputToOutputWeights(input_to_output_weights);
2046 
2047   // lstm.SetInputGateBias(input_gate_bias);
2048   lstm.SetCellBias(cell_gate_bias);
2049   lstm.SetForgetGateBias(forget_gate_bias);
2050   lstm.SetOutputGateBias(output_gate_bias);
2051 
2052   // lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
2053   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
2054   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
2055   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
2056 
2057   lstm.SetProjectionWeights(projection_weights);
2058   lstm.SetProjectionBias(projection_bias);
2059 
2060   // lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
2061   lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
2062   lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
2063   lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
2064 
2065   // Model inputs: one flat {batch x input} vector per sequence step.
2066   const std::vector<std::vector<float>> lstm_input = {
2067       {
2068           0.7, 0.8, 0.1, 0.2, 0.3,  //
2069           0.8, 0.1, 0.2, 0.4, 0.5,  //
2070       },
2071       {
2072           0.2, 0.7, 0.7, 0.1, 0.7,  //
2073           0.3, 0.2, 0.9, 0.8, 0.1,  //
2074       },
2075       {
2076           0.7, 0.8, 0.1, 0.2, 0.3,  //
2077           0.3, 0.2, 0.9, 0.8, 0.1,  //
2078       },
2079   };
2080 
2081   // Expected outputs.
2082   const std::vector<std::vector<int8_t>> expected_output = {
2083       {127, 127, 127, 127, 127, 127},
2084       {127, 127, 127, 127, 127, 127},
2085       {127, 127, 127, 127, 127, 127},
2086   };
2087 
2088   // Invoke and verify the result.
2089   const int input_sequence_size = lstm_input.size();
2090   EXPECT_GT(input_sequence_size, 0);
2091   for (int i = 0; i < input_sequence_size; ++i) {
2092     lstm.SetInput(lstm_input[i]);
2093     ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2094     EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
2095   }
2096 }
2097 
2098 #ifdef GTEST_HAS_DEATH_TEST
2099 TEST(LstmOpTest, InvalidTypes) {
2100   const int n_batch = 1;
2101   const int n_input = 2;
2102   const int n_cell = 4;
2103   const int n_output = 4;
2104 
2105   EXPECT_DEATH(LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
2106                                 /*use_cifg=*/false, /*use_peephole=*/false,
2107                                 /*use_projection_weights=*/false,
2108                                 /*use_projection_bias=*/false,
2109                                 /*weight_type=*/TensorType_INT32,
2110                                 /*model_has_legacy_20_inputs=*/true,
2111                                 /*is_layer_norm=*/false,
2112                                 /*asymmetric_quantize_inputs=*/false),
2113                "");
2114 
2115   EXPECT_DEATH(LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
2116                                 /*use_cifg=*/false, /*use_peephole=*/false,
2117                                 /*use_projection_weights=*/false,
2118                                 /*use_projection_bias=*/false,
2119                                 /*weight_type=*/TensorType_COMPLEX64,
2120                                 /*model_has_legacy_20_inputs=*/true,
2121                                 /*is_layer_norm=*/false,
2122                                 /*asymmetric_quantize_inputs=*/false),
2123                "");
2124 }
2125 #endif
2126 
2127 class HybridSparseLSTMOpModel : public ::tflite::SingleOpModel {
2128  public:
2129   HybridSparseLSTMOpModel(
2130       int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
2131       bool use_peephole, bool use_projection_weights, bool use_projection_bias,
2132       float cell_clip, float proj_clip,
2133       const std::vector<std::vector<int>>& input_shapes,
2134       const TensorData& input_weights_td,
2135       const std::vector<float>& input_to_input_weights,
2136       const std::vector<float>& input_to_forget_weights,
2137       const std::vector<float>& input_to_cell_weights,
2138       const std::vector<float>& input_to_output_weights,
2139       const TensorData& recurrent_weights_td,
2140       const std::vector<float>& recurrent_to_input_weights,
2141       const std::vector<float>& recurrent_to_forget_weights,
2142       const std::vector<float>& recurrent_to_cell_weights,
2143       const std::vector<float>& recurrent_to_output_weights,
2144       const ::tflite::TensorType& weight_type = ::tflite::TensorType_INT8)
2145       : n_batch_(n_batch),
2146         n_input_(n_input),
2147         n_cell_(n_cell),
2148         n_output_(n_output) {
2149     input_ = AddInput(::tflite::TensorType_FLOAT32);
2150 
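    // The gate weight matrices below are added as compile-time constant sparse
    // tensors. With CIFG the input gate is removed, so its weights (and later
    // its bias and layer-norm weights) become null inputs instead.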
    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ =
          AddConstSparseInput(input_weights_td, input_to_input_weights, true);
    }

    input_to_forget_weights_ =
        AddConstSparseInput(input_weights_td, input_to_forget_weights, true);

    input_to_cell_weights_ =
        AddConstSparseInput(input_weights_td, input_to_cell_weights, true);

    input_to_output_weights_ =
        AddConstSparseInput(input_weights_td, input_to_output_weights, true);

    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddConstSparseInput(
          recurrent_weights_td, recurrent_to_input_weights, true);
    }

    recurrent_to_forget_weights_ = AddConstSparseInput(
        recurrent_weights_td, recurrent_to_forget_weights, true);
    recurrent_to_cell_weights_ = AddConstSparseInput(
        recurrent_weights_td, recurrent_to_cell_weights, true);
    recurrent_to_output_weights_ = AddConstSparseInput(
        recurrent_weights_td, recurrent_to_output_weights, true);

    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput(weight_type);
      }
      cell_to_forget_weights_ = AddInput(weight_type);
      cell_to_output_weights_ = AddInput(weight_type);
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
    }
    forget_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
    cell_bias_ = AddInput(::tflite::TensorType_FLOAT32);
    output_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);

    if (use_projection_weights) {
      projection_weights_ = AddInput(weight_type);
      if (use_projection_bias) {
        projection_bias_ = AddInput(::tflite::TensorType_FLOAT32);
      } else {
        projection_bias_ = AddNullInput();
      }
    } else {
      projection_weights_ = AddNullInput();
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 state tensors.
    output_state_ = AddVariableInput(::tflite::TensorData{
        ::tflite::TensorType_FLOAT32, {n_output_ * n_batch_}});
    cell_state_ = AddVariableInput(::tflite::TensorData{
        ::tflite::TensorType_FLOAT32, {n_cell_ * n_batch_}});

    if (use_cifg) {
      input_layer_norm_weights_ = AddNullInput();
    } else {
      input_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
    }
    forget_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
    cell_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
    output_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);

    output_ = AddOutput(::tflite::TensorType_FLOAT32);

    SetBuiltinOp(
        BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
        CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
                          proj_clip, LSTMKernelType_FULL, false)
            .Union());
    BuildInterpreter(input_shapes);
  }

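  // Peephole and projection weights are specified as float in the tests but
  // stored in the hybrid weight type, so they are populated via symmetric
  // quantization; biases and layer-norm weights are populated as plain floats.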
  void SetCellToInputWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
  }

  void SetCellToForgetWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
  }

  void SetCellToOutputWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
  }

  void SetInputLayerNormWeights(std::vector<float> f) {
    PopulateTensor(input_layer_norm_weights_, f);
  }

  void SetForgetLayerNormWeights(std::vector<float> f) {
    PopulateTensor(forget_layer_norm_weights_, f);
  }

  void SetCellLayerNormWeights(std::vector<float> f) {
    PopulateTensor(cell_layer_norm_weights_, f);
  }

  void SetOutputLayerNormWeights(std::vector<float> f) {
    PopulateTensor(output_layer_norm_weights_, f);
  }

  void SetInputGateBias(std::vector<float> f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(std::vector<float> f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(std::vector<float> f) { PopulateTensor(cell_bias_, f); }

  void SetOutputGateBias(std::vector<float> f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(projection_weights_, f);
  }

  void SetProjectionBias(std::vector<float> f) {
    PopulateTensor(projection_bias_, f);
  }

  void SetInput(int offset, const float* begin, const float* end) {
    PopulateTensor(input_, offset, const_cast<float*>(begin),
                   const_cast<float*>(end));
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 protected:
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_layer_norm_weights_;
  int forget_layer_norm_weights_;
  int cell_layer_norm_weights_;
  int output_layer_norm_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;

  int output_state_;
  int cell_state_;

  int output_;

  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;
};

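// Shared fixture for the sparse layer-norm LSTM tests: subclasses fill in the
// weights, shapes, and inputs in SetUp(), and tests drive VerifyGoldens().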
class BaseSparseLstmTest : public ::testing::Test {
 protected:
  // Weights of the Sparse Layer Norm LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> input_layer_norm_weights_;
  std::vector<float> forget_layer_norm_weights_;
  std::vector<float> cell_layer_norm_weights_;
  std::vector<float> output_layer_norm_weights_;
  std::vector<float> projection_weights_;

  std::vector<int> input_to_input_weights_size_;
  std::vector<int> input_to_cell_weights_size_;
  std::vector<int> input_to_forget_weights_size_;
  std::vector<int> input_to_output_weights_size_;
  std::vector<int> recurrent_to_input_weights_size_;
  std::vector<int> recurrent_to_cell_weights_size_;
  std::vector<int> recurrent_to_forget_weights_size_;
  std::vector<int> recurrent_to_output_weights_size_;

  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;
  float cell_clip_;
  float proj_clip_;

  // Layer Norm LSTM input is stored as a num_batch x num_inputs vector.
  std::vector<std::vector<float>> sparse_layer_norm_lstm_input_;

  // Feeds the input through the layer_norm_lstm one time step at a time and
  // compares its output, up to tolerance, against the golden output.
  void VerifyGoldens(const std::vector<std::vector<float>>& input,
                     const std::vector<std::vector<float>>& output,
                     HybridSparseLSTMOpModel* sparse_layer_norm_lstm,
                     float tolerance = 1e-5) {
    const int num_batches = input.size();
    EXPECT_GT(num_batches, 0);
    const int num_inputs = sparse_layer_norm_lstm->num_inputs();
    EXPECT_GT(num_inputs, 0);
    const int input_sequence_size = input[0].size() / num_inputs;
    EXPECT_GT(input_sequence_size, 0);
    for (int i = 0; i < input_sequence_size; ++i) {
      for (int b = 0; b < num_batches; ++b) {
        const float* batch_start = input[b].data() + i * num_inputs;
        const float* batch_end = batch_start + num_inputs;

        sparse_layer_norm_lstm->SetInput(
            b * sparse_layer_norm_lstm->num_inputs(), batch_start, batch_end);
      }

      ASSERT_EQ(sparse_layer_norm_lstm->Invoke(), kTfLiteOk);

      const int num_outputs = sparse_layer_norm_lstm->num_outputs();
      std::vector<float> expected;
      for (int b = 0; b < num_batches; ++b) {
        const float* golden_start_batch = output[b].data() + i * num_outputs;
        const float* golden_end_batch = golden_start_batch + num_outputs;
        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
      }
      EXPECT_THAT(
          sparse_layer_norm_lstm->GetOutput(),
          ElementsAreArray(::tflite::ArrayFloatNear(expected, tolerance)));
    }
  }
};


class NoCifgPeepholeProjectionNoClippingSparseLstmTest
    : public BaseSparseLstmTest {
  void SetUp() override {
    n_batch_ = 2;
    n_input_ = 48;
    n_cell_ = 4;
    n_output_ = 16;
    cell_clip_ = 0.0;
    proj_clip_ = 0.0;

    /* clang-format off */
    input_to_input_weights_ = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
      39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
      -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
      -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
      38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_input_weights_size_ = {4, 48};

    input_to_forget_weights_ = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
      39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
      -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
      -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
      38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_forget_weights_size_ = {4, 48};

    input_to_cell_weights_ = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
      39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
      -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
      -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
      38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_cell_weights_size_ = {4, 48};

    input_to_output_weights_ = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
      39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
      -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
      -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
      38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_output_weights_size_ = {4, 48};

    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};

    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};

    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};

    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    recurrent_to_input_weights_ = {
      -0.2, -0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,   // 1st row
      0.1,  -0.5, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,   // 2nd row
      -0.2, -0.3, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 3rd row
      0.05, -0.2, -0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 4th row
    };
    recurrent_to_input_weights_size_ = {4, 16};

    recurrent_to_cell_weights_ = {
      -0.3, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,     // 1st row
      -0.3, 0.8,  -0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 2nd row
      -0.2, 0.3, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,     // 3rd row
      -0.6, -0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,    // 4th row
    };
    recurrent_to_cell_weights_size_ = {4, 16};

    recurrent_to_forget_weights_ = {
      -0.5, -0.3, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 1st row
      -0.2, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 2nd row
      0.9,  0.3,  -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 3rd row
      0.2, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,    // 4th row
    };
    recurrent_to_forget_weights_size_ = {4, 16};

    recurrent_to_output_weights_ = {
      0.3,  -0.1, 0.1,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 1st row
      -0.2, -0.5, -0.7,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 2nd row
      -0.2, -0.6, -0.1,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 3rd row
      -0.4, -0.7, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0,  // 4th row
    };
    recurrent_to_output_weights_size_ = {4, 16};

    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};

    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};

    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
    forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};

    projection_weights_ = {
      -0.1, 0.2, 0.01, -0.2,  // 1st row
      0.1, 0.5, 0.3, 0.08,    // 2nd row
      0.07, 0.2, -0.4, 0.2,   // 3rd row
      0.0, 0.0, 0.0, 0.0,     // 4th row
      0.0, 0.0, 0.0, 0.0,     // 5th row
      0.0, 0.0, 0.0, 0.0,     // 6th row
      0.0, 0.0, 0.0, 0.0,     // 7th row
      0.0, 0.0, 0.0, 0.0,     // 8th row
      0.0, 0.0, 0.0, 0.0,     // 9th row
      0.0, 0.0, 0.0, 0.0,     // 10th row
      0.0, 0.0, 0.0, 0.0,     // 11th row
      0.0, 0.0, 0.0, 0.0,     // 12th row
      0.0, 0.0, 0.0, 0.0,     // 13th row
      0.0, 0.0, 0.0, 0.0,     // 14th row
      0.0, 0.0, 0.0, 0.0,     // 15th row
      0.0, 0.0, 0.0, 0.0,     // 16th row
    };

    sparse_layer_norm_lstm_input_ = {
      // Batch0: 2 (input_sequence_size) * 48 (n_input_)
      {
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
        -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
        -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // seq 0
        2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
        -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
        0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
        1.0, -2.5, 0.7, -1.9, 0.2,  0.1, 0.2, 0.3,  // seq 1
      },
      // Batch1: 2 (input_sequence_size) * 48 (n_input_)
      {
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
        -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
        -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // seq 0
        2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
        -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
        0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
        1.0, -2.5, 0.7, -1.9, 0.2, -1.0, 1.0, -1.0,   // seq 1
      },
    };
    /* clang-format on */
  }
};

TEST_F(NoCifgPeepholeProjectionNoClippingSparseLstmTest,
       HybridSparseLstmBlackBoxTest) {
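  // Both weight matrices use TFLite's block-sparse encoding: dimension 0 is
  // stored dense and dimension 1 is compressed (CSR) in blocks of 16, which
  // matches the 16-element bands of nonzeros in the weights set up above.
  // traversal_order lists the two original dimensions followed by the block
  // dimension, and block_map names the blocked dimension. For example, row 0
  // of the 4x48 input weights has nonzeros only in column blocks 0 and 2, so
  // only those two 1x16 blocks need to be stored.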
  TensorData input_weight = {};
  input_weight.type = TensorType_FLOAT32;
  input_weight.shape = {4, 48};
  input_weight.traversal_order = {0, 1, 2};
  input_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  input_weight.block_map = {1};
  input_weight.block_size = {16};
  TensorData recurrent_weight = {};
  recurrent_weight.type = TensorType_FLOAT32;
  recurrent_weight.shape = {4, 16};
  recurrent_weight.traversal_order = {0, 1, 2};
  recurrent_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  recurrent_weight.block_map = {1};
  recurrent_weight.block_size = {16};
  HybridSparseLSTMOpModel sparse_layer_norm_lstm(
      n_batch_, n_input_, n_cell_, n_output_,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip_, proj_clip_,
      {
          {n_batch_, n_input_},  // input tensor

          {input_to_input_weights_size_},
          {input_to_forget_weights_size_},
          {input_to_cell_weights_size_},
          {input_to_output_weights_size_},

          {recurrent_to_input_weights_size_},
          {recurrent_to_forget_weights_size_},
          {recurrent_to_cell_weights_size_},
          {recurrent_to_output_weights_size_},

          {n_cell_},  // cell_to_input_weight tensor
          {n_cell_},  // cell_to_forget_weight tensor
          {n_cell_},  // cell_to_output_weight tensor

          {n_cell_},  // input_gate_bias tensor
          {n_cell_},  // forget_gate_bias tensor
          {n_cell_},  // cell_bias tensor
          {n_cell_},  // output_gate_bias tensor

          {n_output_, n_cell_},  // projection_weight tensor
          {0},                   // projection_bias tensor

          {n_output_ * n_batch_},  // output_state tensor
          {n_cell_ * n_batch_},    // cell_state tensor

          {n_cell_},  // input_layer_norm_weight tensor
          {n_cell_},  // forget_layer_norm_weight tensor
          {n_cell_},  // cell_layer_norm_weight tensor
          {n_cell_},  // output_layer_norm_weight tensor
      },
      input_weight, input_to_input_weights_, input_to_forget_weights_,
      input_to_cell_weights_, input_to_output_weights_, recurrent_weight,
      recurrent_to_input_weights_, recurrent_to_forget_weights_,
      recurrent_to_cell_weights_, recurrent_to_output_weights_);

  sparse_layer_norm_lstm.SetInputGateBias(input_gate_bias_);
  sparse_layer_norm_lstm.SetCellBias(cell_gate_bias_);
  sparse_layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  sparse_layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  sparse_layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
  sparse_layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  sparse_layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  sparse_layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
  sparse_layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
  sparse_layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
  sparse_layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);

  sparse_layer_norm_lstm.SetProjectionWeights(projection_weights_);

  /* clang-format off */
  const std::vector<std::vector<float>> sparse_layer_norm_lstm_golden_output = {
    {
      // Batch0: 2 (input_sequence_size) * 16 (n_output_)
      0.0559981, 0.140761, -0.0618812, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0,
      0.070831, 0.200455, -0.0581763, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0,
    },
    {
      // Batch1: 2 (input_sequence_size) * 16 (n_output_)
      0.0559981, 0.140761, -0.0618812, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0,
      0.070831, 0.200455, -0.0581763, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0,
    }};
  /* clang-format on */

  VerifyGoldens(sparse_layer_norm_lstm_input_,
                sparse_layer_norm_lstm_golden_output, &sparse_layer_norm_lstm);
}

// Test parameters control the weight type, model_has_legacy_20_inputs, and
// asymmetric_quantize_inputs arguments of LSTMOpModel.
INSTANTIATE_TEST_SUITE_P(
    Parameterized, LstmOpTest,
    ::testing::Combine(::testing::Values(TensorType_FLOAT32, TensorType_UINT8,
                                         TensorType_INT8),
                       ::testing::Bool(), ::testing::Bool()));

}  // namespace
}  // namespace tflite