xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/NEON/LSTMLayerQuantized.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2019-2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "arm_compute/runtime/NEON/functions/NELSTMLayerQuantized.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "tests/NEON/Accessor.h"
27*c217d954SCole Faust #include "tests/PaddingCalculator.h"
28*c217d954SCole Faust #include "tests/Utils.h"
29*c217d954SCole Faust #include "tests/datasets/LSTMLayerDataset.h"
30*c217d954SCole Faust #include "tests/framework/Asserts.h"
31*c217d954SCole Faust #include "tests/framework/Macros.h"
32*c217d954SCole Faust #include "tests/framework/datasets/Datasets.h"
33*c217d954SCole Faust #include "tests/validation/Validation.h"
34*c217d954SCole Faust 
35*c217d954SCole Faust #include <vector>
36*c217d954SCole Faust 
37*c217d954SCole Faust namespace arm_compute
38*c217d954SCole Faust {
39*c217d954SCole Faust namespace test
40*c217d954SCole Faust {
41*c217d954SCole Faust namespace validation
42*c217d954SCole Faust {
43*c217d954SCole Faust namespace
44*c217d954SCole Faust {
45*c217d954SCole Faust template <typename T>
fill_tensor(Tensor & tensor,const std::vector<T> & v)46*c217d954SCole Faust inline void fill_tensor(Tensor &tensor, const std::vector<T> &v)
47*c217d954SCole Faust {
48*c217d954SCole Faust     // Import memory accounting for padding
49*c217d954SCole Faust     TensorShape t_shape = tensor.info()->tensor_shape();
50*c217d954SCole Faust     Window      window;
51*c217d954SCole Faust     window.use_tensor_dimensions(t_shape);
52*c217d954SCole Faust     Iterator out(&tensor, window);
53*c217d954SCole Faust     execute_window_loop(window, [&](const Coordinates & id)
54*c217d954SCole Faust     {
55*c217d954SCole Faust         *reinterpret_cast<T *>(out.ptr()) = v[coord2index(t_shape, id)];
56*c217d954SCole Faust     },
57*c217d954SCole Faust     out);
58*c217d954SCole Faust }
59*c217d954SCole Faust 
60*c217d954SCole Faust template <typename T>
fill_tensor(SimpleTensor<T> & tensor,const std::vector<T> & v)61*c217d954SCole Faust inline void fill_tensor(SimpleTensor<T> &tensor, const std::vector<T> &v)
62*c217d954SCole Faust {
63*c217d954SCole Faust     std::memcpy(tensor.data(), v.data(), sizeof(T) * v.size());
64*c217d954SCole Faust }
65*c217d954SCole Faust 
66*c217d954SCole Faust /** Tolerance for quantized asymmetric operations */
67*c217d954SCole Faust #if defined(__aarch64__)
68*c217d954SCole Faust constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(0);
69*c217d954SCole Faust #else  // defined(__aarch64__)
70*c217d954SCole Faust constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(1);
71*c217d954SCole Faust #endif // defined(__aarch64__)
72*c217d954SCole Faust 
73*c217d954SCole Faust } // namespace
74*c217d954SCole Faust 
75*c217d954SCole Faust TEST_SUITE(NEON)
TEST_SUITE(LSTMLayerQuantized)76*c217d954SCole Faust TEST_SUITE(LSTMLayerQuantized)
77*c217d954SCole Faust 
78*c217d954SCole Faust // *INDENT-OFF*
79*c217d954SCole Faust // clang-format off
80*c217d954SCole Faust TEST_SUITE(IntegrationTestCase)
81*c217d954SCole Faust TEST_SUITE(MultSmallerEq1)
82*c217d954SCole Faust TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
83*c217d954SCole Faust {
84*c217d954SCole Faust     const int batch_size  = 2;
85*c217d954SCole Faust     const int input_size  = 2;
86*c217d954SCole Faust     const int output_size = 4;
87*c217d954SCole Faust 
88*c217d954SCole Faust 
89*c217d954SCole Faust     QuantizationInfo qasymm(1.f / 128.f, 128);
90*c217d954SCole Faust     QuantizationInfo qweights(1.f / 128.f, 128);
91*c217d954SCole Faust     QuantizationInfo qsymm_3(8.f / 32768.f, 0);
92*c217d954SCole Faust     QuantizationInfo qsymm_4(16.f / 32768.f, 0);
93*c217d954SCole Faust 
94*c217d954SCole Faust     TensorShape input_shape{ input_size, batch_size };
95*c217d954SCole Faust     TensorShape input_weights_shape{ input_size, output_size };
96*c217d954SCole Faust     TensorShape recurrent_weights_shape{ output_size, output_size };
97*c217d954SCole Faust     TensorShape output_shape{ output_size, batch_size};
98*c217d954SCole Faust     TensorShape bias_shape{ output_size };
99*c217d954SCole Faust 
100*c217d954SCole Faust     auto input_to_input_weights      = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
101*c217d954SCole Faust     auto input_to_forget_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
102*c217d954SCole Faust     auto input_to_cell_weights       = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
103*c217d954SCole Faust     auto input_to_output_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
104*c217d954SCole Faust     auto recurrent_to_input_weights  = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
105*c217d954SCole Faust     auto recurrent_to_forget_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
106*c217d954SCole Faust     auto recurrent_to_cell_weights   = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
107*c217d954SCole Faust     auto recurrent_to_output_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
108*c217d954SCole Faust     auto input_gate_bias             = create_tensor<Tensor>(bias_shape, DataType::S32);
109*c217d954SCole Faust     auto forget_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
110*c217d954SCole Faust     auto cell_gate_bias              = create_tensor<Tensor>(bias_shape, DataType::S32);
111*c217d954SCole Faust     auto output_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
112*c217d954SCole Faust 
113*c217d954SCole Faust     // LSTM input
114*c217d954SCole Faust     auto input = create_tensor<Tensor>(input_shape, DataType::QASYMM8, 1, qasymm);
115*c217d954SCole Faust 
116*c217d954SCole Faust     // LSTM output state
117*c217d954SCole Faust     auto output_state = create_tensor<Tensor>(output_shape, DataType::QASYMM8, 1, qasymm);
118*c217d954SCole Faust 
119*c217d954SCole Faust     // LSTM cell state
120*c217d954SCole Faust     auto cell_state = create_tensor<Tensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
121*c217d954SCole Faust 
122*c217d954SCole Faust     NELSTMLayerQuantized lstmq;
123*c217d954SCole Faust 
124*c217d954SCole Faust     lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
125*c217d954SCole Faust                     &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
126*c217d954SCole Faust                     &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
127*c217d954SCole Faust 
128*c217d954SCole Faust     input.allocator()->allocate();
129*c217d954SCole Faust     input_to_input_weights.allocator()->allocate();
130*c217d954SCole Faust     input_to_forget_weights.allocator()->allocate();
131*c217d954SCole Faust     input_to_cell_weights.allocator()->allocate();
132*c217d954SCole Faust     input_to_output_weights.allocator()->allocate();
133*c217d954SCole Faust     recurrent_to_input_weights.allocator()->allocate();
134*c217d954SCole Faust     recurrent_to_forget_weights.allocator()->allocate();
135*c217d954SCole Faust     recurrent_to_cell_weights.allocator()->allocate();
136*c217d954SCole Faust     recurrent_to_output_weights.allocator()->allocate();
137*c217d954SCole Faust     input_gate_bias.allocator()->allocate();
138*c217d954SCole Faust     forget_gate_bias.allocator()->allocate();
139*c217d954SCole Faust     cell_gate_bias.allocator()->allocate();
140*c217d954SCole Faust     output_gate_bias.allocator()->allocate();
141*c217d954SCole Faust     cell_state.allocator()->allocate();
142*c217d954SCole Faust     output_state.allocator()->allocate();
143*c217d954SCole Faust 
144*c217d954SCole Faust     // Fill weights and biases
145*c217d954SCole Faust     fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 47,  168,
146*c217d954SCole Faust                                                               66,  239,
147*c217d954SCole Faust                                                                6,   42,
148*c217d954SCole Faust                                                              237,  236 });
149*c217d954SCole Faust 
150*c217d954SCole Faust     fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 204,  193,
151*c217d954SCole Faust                                                                 148,  59,
152*c217d954SCole Faust                                                                 113,  17,
153*c217d954SCole Faust                                                                  66, 197 });
154*c217d954SCole Faust 
155*c217d954SCole Faust     fill_tensor(input_to_cell_weights, std::vector<uint8_t> { 172,  101,
156*c217d954SCole Faust                                                               184, 209,
157*c217d954SCole Faust                                                               165,  82,
158*c217d954SCole Faust                                                               108, 209 });
159*c217d954SCole Faust 
160*c217d954SCole Faust     fill_tensor(input_to_output_weights, std::vector<uint8_t> { 203, 244,
161*c217d954SCole Faust                                                                 219, 114,
162*c217d954SCole Faust                                                                 130,  16,
163*c217d954SCole Faust                                                                 163, 222 });
164*c217d954SCole Faust 
165*c217d954SCole Faust     fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 162, 168,  7,  95,
166*c217d954SCole Faust                                                                     91, 155, 108, 216,
167*c217d954SCole Faust                                                                    255, 100,  48, 188,
168*c217d954SCole Faust                                                                     58,  37, 186, 147 });
169*c217d954SCole Faust 
170*c217d954SCole Faust     fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> {  46,  58,  47, 170,
171*c217d954SCole Faust                                                                     246,  96,  12,  99,
172*c217d954SCole Faust                                                                      68,  23, 186, 161,
173*c217d954SCole Faust                                                                     237, 164,  89,   6 });
174*c217d954SCole Faust 
175*c217d954SCole Faust     fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> { 234,  99,   71, 206,
176*c217d954SCole Faust                                                                   205, 159,   64, 253,
177*c217d954SCole Faust                                                                   191, 148,  116,   8,
178*c217d954SCole Faust                                                                   209, 136,   59, 138 });
179*c217d954SCole Faust 
180*c217d954SCole Faust     fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> {  23, 241, 137, 36,
181*c217d954SCole Faust                                                                     206,   5, 227, 56,
182*c217d954SCole Faust                                                                     254, 176, 231, 47,
183*c217d954SCole Faust                                                                      18, 201, 161, 11 });
184*c217d954SCole Faust 
185*c217d954SCole Faust     fill_tensor(input_gate_bias, std::vector<int>  {-103038,   30525,  115255, -38154 });
186*c217d954SCole Faust     fill_tensor(forget_gate_bias, std::vector<int> { -23428,  126970,  116806,  46307 });
187*c217d954SCole Faust     fill_tensor(cell_gate_bias, std::vector<int>   { 128006,   69949,  -42808,  42568 });
188*c217d954SCole Faust     fill_tensor(output_gate_bias, std::vector<int> { -67066,  -53607,   47233,  7300  });
189*c217d954SCole Faust 
190*c217d954SCole Faust     SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
191*c217d954SCole Faust 
192*c217d954SCole Faust     // Initialize state
193*c217d954SCole Faust     fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128,
194*c217d954SCole Faust                                                      128, 128, 128, 128 });
195*c217d954SCole Faust     fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0,
196*c217d954SCole Faust                                                    0, 0, 0, 0 });
197*c217d954SCole Faust 
198*c217d954SCole Faust     // First input
199*c217d954SCole Faust     fill_tensor(input, std::vector<uint8_t> { 106,  193,
200*c217d954SCole Faust                                               155,  150 });
201*c217d954SCole Faust 
202*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 128, 130,  36, 134,
203*c217d954SCole Faust                                                         128, 131,  35, 133 });
204*c217d954SCole Faust 
205*c217d954SCole Faust     lstmq.run();
206*c217d954SCole Faust     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
207*c217d954SCole Faust 
208*c217d954SCole Faust     // Second input
209*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 128, 129, 12, 137,
210*c217d954SCole Faust                                                         128, 131, 10, 136 });
211*c217d954SCole Faust     lstmq.run();
212*c217d954SCole Faust     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
213*c217d954SCole Faust 
214*c217d954SCole Faust     // Third input
215*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 128, 129, 8, 140,
216*c217d954SCole Faust                                                         128, 130, 6, 138 });
217*c217d954SCole Faust     lstmq.run();
218*c217d954SCole Faust     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
219*c217d954SCole Faust }
220*c217d954SCole Faust 
TEST_CASE(RunLarge,framework::DatasetMode::PRECOMMIT)221*c217d954SCole Faust TEST_CASE(RunLarge, framework::DatasetMode::PRECOMMIT)
222*c217d954SCole Faust {
223*c217d954SCole Faust     const int batch_size  = 16;
224*c217d954SCole Faust     const int input_size  = 8;
225*c217d954SCole Faust     const int output_size = 8;
226*c217d954SCole Faust 
227*c217d954SCole Faust 
228*c217d954SCole Faust     QuantizationInfo qasymm(1.f / 128.f, 128);
229*c217d954SCole Faust     QuantizationInfo qweights(1.f / 128.f, 128);
230*c217d954SCole Faust     QuantizationInfo qsymm_3(8.f / 32768.f, 0);
231*c217d954SCole Faust     QuantizationInfo qsymm_4(16.f / 32768.f, 0);
232*c217d954SCole Faust 
233*c217d954SCole Faust     TensorShape input_shape{ input_size, batch_size };
234*c217d954SCole Faust     TensorShape input_weights_shape{ input_size, output_size };
235*c217d954SCole Faust     TensorShape recurrent_weights_shape{ output_size, output_size };
236*c217d954SCole Faust     TensorShape output_shape{ output_size, batch_size};
237*c217d954SCole Faust     TensorShape bias_shape{ output_size };
238*c217d954SCole Faust 
239*c217d954SCole Faust     auto input_to_input_weights      = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
240*c217d954SCole Faust     auto input_to_forget_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
241*c217d954SCole Faust     auto input_to_cell_weights       = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
242*c217d954SCole Faust     auto input_to_output_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
243*c217d954SCole Faust     auto recurrent_to_input_weights  = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
244*c217d954SCole Faust     auto recurrent_to_forget_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
245*c217d954SCole Faust     auto recurrent_to_cell_weights   = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
246*c217d954SCole Faust     auto recurrent_to_output_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
247*c217d954SCole Faust     auto input_gate_bias             = create_tensor<Tensor>(bias_shape, DataType::S32);
248*c217d954SCole Faust     auto forget_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
249*c217d954SCole Faust     auto cell_gate_bias              = create_tensor<Tensor>(bias_shape, DataType::S32);
250*c217d954SCole Faust     auto output_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
251*c217d954SCole Faust 
252*c217d954SCole Faust     // LSTM input
253*c217d954SCole Faust     auto input = create_tensor<Tensor>(input_shape, DataType::QASYMM8, 1, qasymm);
254*c217d954SCole Faust 
255*c217d954SCole Faust     // LSTM output state
256*c217d954SCole Faust     auto output_state = create_tensor<Tensor>(output_shape, DataType::QASYMM8, 1, qasymm);
257*c217d954SCole Faust 
258*c217d954SCole Faust     // LSTM cell state
259*c217d954SCole Faust     auto cell_state = create_tensor<Tensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
260*c217d954SCole Faust 
261*c217d954SCole Faust     NELSTMLayerQuantized lstmq;
262*c217d954SCole Faust 
263*c217d954SCole Faust     lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
264*c217d954SCole Faust                     &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
265*c217d954SCole Faust                     &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
266*c217d954SCole Faust 
267*c217d954SCole Faust     input.allocator()->allocate();
268*c217d954SCole Faust     input_to_input_weights.allocator()->allocate();
269*c217d954SCole Faust     input_to_forget_weights.allocator()->allocate();
270*c217d954SCole Faust     input_to_cell_weights.allocator()->allocate();
271*c217d954SCole Faust     input_to_output_weights.allocator()->allocate();
272*c217d954SCole Faust     recurrent_to_input_weights.allocator()->allocate();
273*c217d954SCole Faust     recurrent_to_forget_weights.allocator()->allocate();
274*c217d954SCole Faust     recurrent_to_cell_weights.allocator()->allocate();
275*c217d954SCole Faust     recurrent_to_output_weights.allocator()->allocate();
276*c217d954SCole Faust     input_gate_bias.allocator()->allocate();
277*c217d954SCole Faust     forget_gate_bias.allocator()->allocate();
278*c217d954SCole Faust     cell_gate_bias.allocator()->allocate();
279*c217d954SCole Faust     output_gate_bias.allocator()->allocate();
280*c217d954SCole Faust     cell_state.allocator()->allocate();
281*c217d954SCole Faust     output_state.allocator()->allocate();
282*c217d954SCole Faust 
283*c217d954SCole Faust     // Fill weights and biases
284*c217d954SCole Faust     fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 141,  89, 200, 180,  46,  50,  87, 128,
285*c217d954SCole Faust                                                               149, 227, 177, 187, 212, 229,  54, 111,
286*c217d954SCole Faust                                                               131, 116,   3,  58, 196,  26, 131, 255,
287*c217d954SCole Faust                                                                22, 106, 216,  69, 239,  12, 232, 207,
288*c217d954SCole Faust                                                               184,  56, 236, 172,  28, 143, 161, 124,
289*c217d954SCole Faust                                                               255,  33, 197, 122,  47, 197,  26, 229,
290*c217d954SCole Faust                                                                91,  79,  11, 160,  26,  80, 100,  36,
291*c217d954SCole Faust                                                               248, 186,  97,  61, 125,  46,  14, 100, });
292*c217d954SCole Faust 
293*c217d954SCole Faust     fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 237, 165, 141, 249,  72, 116, 36 , 115,
294*c217d954SCole Faust                                                                 234, 213,  85,  84,  59,  62, 150, 246,
295*c217d954SCole Faust                                                                 182, 102, 158, 214, 182, 183,  94,  11,
296*c217d954SCole Faust                                                                 158, 192,  92, 189, 160, 219, 206, 249,
297*c217d954SCole Faust                                                                  88, 213, 193, 244, 151,  72, 129,  49,
298*c217d954SCole Faust                                                                 239,  83, 106,   9, 169, 187, 125, 171,
299*c217d954SCole Faust                                                                  32, 141, 126,  92,  13,  36, 224, 150,
300*c217d954SCole Faust                                                                 187, 250, 178, 169,  89, 214,  91, 173 });
301*c217d954SCole Faust 
302*c217d954SCole Faust     fill_tensor(input_to_cell_weights, std::vector<uint8_t> {  93, 103, 226, 139, 185, 252, 129, 171,
303*c217d954SCole Faust                                                               159,  32,  25, 175, 224, 183, 165,  35,
304*c217d954SCole Faust                                                               207,  69, 238, 228, 149, 214,  79,   6,
305*c217d954SCole Faust                                                                 5,  66, 102,  14,  19, 111,  36, 143,
306*c217d954SCole Faust                                                                22,  85,  13,  78, 236, 121, 122,  77,
307*c217d954SCole Faust                                                               249,  39,  88,  12, 205, 143,  93, 240,
308*c217d954SCole Faust                                                               167,  89, 188,  50,  73,  69, 201, 251,
309*c217d954SCole Faust                                                                59,  32, 203, 184, 139, 191, 199,  74});
310*c217d954SCole Faust 
311*c217d954SCole Faust     fill_tensor(input_to_output_weights, std::vector<uint8_t> { 205,   7,  95, 104, 252, 143, 226,  73,
312*c217d954SCole Faust                                                                 229, 114, 152, 171, 221, 153,  73, 229,
313*c217d954SCole Faust                                                                 153, 165, 223, 239, 100,  38, 172, 211,
314*c217d954SCole Faust                                                                 226, 133, 239, 207, 116, 230, 170, 100,
315*c217d954SCole Faust                                                                 241,  95, 171, 124,  63, 115,  32, 127,
316*c217d954SCole Faust                                                                 141, 239,  53, 193, 201,  53, 104, 178,
317*c217d954SCole Faust                                                                 186, 212, 167, 107, 226, 230,  71, 213,
318*c217d954SCole Faust                                                                 148, 217,  19, 248, 233, 195, 183, 156 });
319*c217d954SCole Faust 
320*c217d954SCole Faust     fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 147, 112, 140, 103,   3, 255,  17,  49,
321*c217d954SCole Faust                                                                     84, 112, 144, 213, 138, 142, 112,  66,
322*c217d954SCole Faust                                                                    117,  30, 101,  35,  25, 132, 211, 229,
323*c217d954SCole Faust                                                                    183, 208, 102,  16,  38,  85, 101, 152,
324*c217d954SCole Faust                                                                    226,  83, 132,  22, 161, 110, 157, 129,
325*c217d954SCole Faust                                                                    184,  63, 168,  42, 220, 126, 209, 157,
326*c217d954SCole Faust                                                                      5,  88, 243,  83, 249,  19, 226, 209,
327*c217d954SCole Faust                                                                    173,  96, 185,  77, 146, 227, 238, 136 });
328*c217d954SCole Faust 
329*c217d954SCole Faust 
330*c217d954SCole Faust     fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> {  52, 132,  92, 200, 213,  32, 213,  37,
331*c217d954SCole Faust                                                                     116, 142, 116, 180,   4, 172, 158, 143,
332*c217d954SCole Faust                                                                     110,  40,  99,  28, 221, 153, 133,   2,
333*c217d954SCole Faust                                                                     247, 144, 198, 100,  20,  15, 221, 196,
334*c217d954SCole Faust                                                                     159, 178, 188, 151, 171,  15,  25, 217,
335*c217d954SCole Faust                                                                     178, 109, 110, 118, 128,  39, 232, 234,
336*c217d954SCole Faust                                                                     184, 214, 177,  13,  56,   6,  28, 252,
337*c217d954SCole Faust                                                                      89, 187, 242,  59, 146, 111, 132, 129});
338*c217d954SCole Faust 
339*c217d954SCole Faust     fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> {  70,  44, 137,  29,  36, 127,   1, 241,
340*c217d954SCole Faust                                                                    26, 241, 142, 114,  67, 181,  49,  57,
341*c217d954SCole Faust                                                                   131, 152, 175,  77,  23,  63,  37, 124,
342*c217d954SCole Faust                                                                   150, 113,  95, 103, 110, 201,  69,  97,
343*c217d954SCole Faust                                                                   196, 242,  62, 214,  66,  19,  45, 135,
344*c217d954SCole Faust                                                                    22, 168, 149, 104,  77, 101,  36,  68,
345*c217d954SCole Faust                                                                   170, 116, 222, 100, 109,   1, 154,  18,
346*c217d954SCole Faust                                                                   133, 215, 105,  93,  31,  57, 231, 112 });
347*c217d954SCole Faust 
348*c217d954SCole Faust 
349*c217d954SCole Faust     fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> { 45 ,  181 ,  220 ,  219 ,   49  ,  63 ,   49  , 129,
350*c217d954SCole Faust                                                                      7 ,  166 ,  104 ,  114 ,   83  ,  40 ,    1  , 195,
351*c217d954SCole Faust                                                                    245 ,  142 ,   82 ,  232 ,  104  , 245 ,   82  , 196,
352*c217d954SCole Faust                                                                    111 ,   56 ,  156 ,    9 ,  141  , 240 ,  180  , 148,
353*c217d954SCole Faust                                                                    247 ,  198 ,  234 ,  137 ,   13  , 210 ,  161  , 192,
354*c217d954SCole Faust                                                                    196 ,   59 ,  233 ,  184 ,  142  , 187 ,  140  , 166,
355*c217d954SCole Faust                                                                      2 ,   95 ,  152 ,   46 ,   71  ,  46 ,  113  ,  32,
356*c217d954SCole Faust                                                                    175 ,  229 ,   86 ,   87 ,   62  ,  93 ,   74  , 130});
357*c217d954SCole Faust 
358*c217d954SCole Faust     fill_tensor(input_gate_bias, std::vector<int>  {  -40040, -106916,  -92315,  -79123,   45160, -17954,   50962, -63758 });
359*c217d954SCole Faust     fill_tensor(forget_gate_bias, std::vector<int> { -128514,    8463,  -57831,  116977,  106547, -28132, -124557,  44941 });
360*c217d954SCole Faust     fill_tensor(cell_gate_bias, std::vector<int>   { 88388  ,  123601, -116148,  -13022,   21619,  48926,   57523,  39332 });
361*c217d954SCole Faust     fill_tensor(output_gate_bias, std::vector<int> {  59485 ,  -33070,   21386, -100633, -115959, 125768,  -56407,  24897 });
362*c217d954SCole Faust 
363*c217d954SCole Faust     SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
364*c217d954SCole Faust 
365*c217d954SCole Faust     // Initialize state
366*c217d954SCole Faust     fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128, 128, 128, 128, 128,
367*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
368*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
369*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
370*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
371*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
372*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
373*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
374*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
375*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
376*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
377*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
378*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
379*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
380*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128,
381*c217d954SCole Faust                                                      128, 128, 128, 128, 128, 128, 128, 128 });
382*c217d954SCole Faust 
383*c217d954SCole Faust     fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0, 0, 0, 0, 0,
384*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
385*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
386*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
387*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
388*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
389*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
390*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
391*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
392*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
393*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
394*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
395*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
396*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
397*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0,
398*c217d954SCole Faust                                                    0, 0, 0, 0, 0, 0, 0, 0});
399*c217d954SCole Faust 
400*c217d954SCole Faust     // First input
401*c217d954SCole Faust     fill_tensor(input, std::vector<uint8_t> { 247,  203, 159, 131, 182, 114, 207, 195,
402*c217d954SCole Faust                                               48 ,  61 , 154,  16,  80, 101, 116, 255,
403*c217d954SCole Faust                                               50 , 115 ,  45, 186,  75, 212,  98,  48,
404*c217d954SCole Faust                                               88 , 146 ,  24, 143, 218, 174, 203, 200,
405*c217d954SCole Faust                                              239 ,  16 ,  66, 136, 234,  54,  94,  51,
406*c217d954SCole Faust                                              101 , 128 , 220, 213, 164,  82, 137, 255,
407*c217d954SCole Faust                                               70 , 165 , 234, 220,  66,  35, 183, 206,
408*c217d954SCole Faust                                               39 ,  57 , 180, 202,  23, 172, 224, 109,
409*c217d954SCole Faust                                              102 , 215 , 186,  82, 215, 147,  85, 187,
410*c217d954SCole Faust                                               96 , 249 ,  59, 116, 150,  44, 167, 128,
411*c217d954SCole Faust                                               34 , 217 , 148, 193, 243,  38, 250, 208,
412*c217d954SCole Faust                                              112 , 130 , 208,  29,  16, 122,  20,  92,
413*c217d954SCole Faust                                               24 ,  72 , 104,  29, 150, 233, 151,  19,
414*c217d954SCole Faust                                              158 , 192 , 254,  70,  73, 142, 106, 152,
415*c217d954SCole Faust                                                3 ,  61 ,  24, 135, 212,   9,  80, 234,
416*c217d954SCole Faust                                              147 , 246 ,  83, 249,  49,  14,  68,  50});
417*c217d954SCole Faust 
418*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> {131, 128,  128,  128,  128,  180,  129,  133,
419*c217d954SCole Faust                                                        136, 128,  126,  128,  128,  173,  135,  130,
420*c217d954SCole Faust                                                        160, 128,  128,  128,  128,  138,  132,  129,
421*c217d954SCole Faust                                                        131, 128,  127,  128,  128,  169,  129,  131,
422*c217d954SCole Faust                                                        133, 128,  128,  128,  128,  182,  130,  129,
423*c217d954SCole Faust                                                        131, 128,  128,  128,  128,  163,  129,  130,
424*c217d954SCole Faust                                                        131, 128,  128,  128,  128,  149,  132,  129,
425*c217d954SCole Faust                                                        143, 128,  127,  128,  128,  150,  134,  131,
426*c217d954SCole Faust                                                        134, 128,  128,  128,  128,  167,  130,  130,
427*c217d954SCole Faust                                                        131, 128,  128,  128,  128,  152,  132,  129,
428*c217d954SCole Faust                                                        128, 128,  128,  128,  128,  169,  130,  130,
429*c217d954SCole Faust                                                        173, 128,  128,  128,  128,  148,  139,  130,
430*c217d954SCole Faust                                                        152, 128,  128,  128,  128,  168,  139,  132,
431*c217d954SCole Faust                                                        147, 128,  128,  128,  128,  161,  131,  132,
432*c217d954SCole Faust                                                        130, 128,  128,  128,  128,  159,  134,  128,
433*c217d954SCole Faust                                                        140, 128,  128,  128,  128,  133,  132,  128 });
434*c217d954SCole Faust 
435*c217d954SCole Faust     lstmq.run();
436*c217d954SCole Faust     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
437*c217d954SCole Faust 
438*c217d954SCole Faust     // Second input
439*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 130,   128,   128,   128,   128,   205,   129,   137,
440*c217d954SCole Faust                                                         135,   128,   127,   128,   128,   190,   137,   132,
441*c217d954SCole Faust                                                         160,   128,   128,   128,   128,   142,   133,   131,
442*c217d954SCole Faust                                                         130,   128,   128,   128,   128,   185,   129,   133,
443*c217d954SCole Faust                                                         132,   128,   128,   128,   128,   198,   131,   130,
444*c217d954SCole Faust                                                         130,   128,   128,   128,   128,   178,   130,   131,
445*c217d954SCole Faust                                                         131,   128,   128,   128,   128,   158,   132,   131,
446*c217d954SCole Faust                                                         142,   128,   127,   128,   128,   158,   135,   134,
447*c217d954SCole Faust                                                         133,   128,   128,   128,   128,   178,   131,   132,
448*c217d954SCole Faust                                                         131,   128,   128,   128,   128,   160,   132,   130,
449*c217d954SCole Faust                                                         128,   128,   128,   128,   128,   190,   131,   131,
450*c217d954SCole Faust                                                         170,   128,   128,   128,   128,   157,   142,   131,
451*c217d954SCole Faust                                                         149,   128,   128,   128,   128,   178,   142,   135,
452*c217d954SCole Faust                                                         145,   128,   128,   128,   129,   173,   132,   135,
453*c217d954SCole Faust                                                         129,   128,   128,   128,   128,   171,   134,   129,
454*c217d954SCole Faust                                                         140,   128,   128,   128,   128,   135,   132,   129});
455*c217d954SCole Faust     lstmq.run();
456*c217d954SCole Faust     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
457*c217d954SCole Faust }
458*c217d954SCole Faust TEST_SUITE_END() // MultSmallerEq1
459*c217d954SCole Faust 
TEST_SUITE(MultGreater1)460*c217d954SCole Faust TEST_SUITE(MultGreater1)
461*c217d954SCole Faust TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
462*c217d954SCole Faust {
463*c217d954SCole Faust     //Input sequence length is 1
464*c217d954SCole Faust     const int batch_size  = 2;
465*c217d954SCole Faust     const int input_size  = 2;
466*c217d954SCole Faust     const int output_size = 4;
467*c217d954SCole Faust 
468*c217d954SCole Faust     QuantizationInfo qasymm(1.f / 128.f, 128);
469*c217d954SCole Faust     QuantizationInfo qweights(1.f / 16.f, 16);
470*c217d954SCole Faust     QuantizationInfo qsymm_3(8.f / 32768.f, 0);
471*c217d954SCole Faust     QuantizationInfo qsymm_4(16.f / 32768.f, 0);
472*c217d954SCole Faust 
473*c217d954SCole Faust     TensorShape input_shape{ input_size, batch_size };
474*c217d954SCole Faust     TensorShape input_weights_shape{ input_size, output_size };
475*c217d954SCole Faust     TensorShape recurrent_weights_shape{ output_size, output_size };
476*c217d954SCole Faust     TensorShape output_shape{ output_size, batch_size};
477*c217d954SCole Faust     TensorShape bias_shape{ output_size };
478*c217d954SCole Faust 
479*c217d954SCole Faust     auto input_to_input_weights      = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
480*c217d954SCole Faust     auto input_to_forget_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
481*c217d954SCole Faust     auto input_to_cell_weights       = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
482*c217d954SCole Faust     auto input_to_output_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
483*c217d954SCole Faust     auto recurrent_to_input_weights  = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
484*c217d954SCole Faust     auto recurrent_to_forget_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
485*c217d954SCole Faust     auto recurrent_to_cell_weights   = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
486*c217d954SCole Faust     auto recurrent_to_output_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
487*c217d954SCole Faust     auto input_gate_bias             = create_tensor<Tensor>(bias_shape, DataType::S32);
488*c217d954SCole Faust     auto forget_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
489*c217d954SCole Faust     auto cell_gate_bias              = create_tensor<Tensor>(bias_shape, DataType::S32);
490*c217d954SCole Faust     auto output_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
491*c217d954SCole Faust 
492*c217d954SCole Faust     // LSTM input
493*c217d954SCole Faust     auto input = create_tensor<Tensor>(input_shape, DataType::QASYMM8, 1, qasymm);
494*c217d954SCole Faust 
495*c217d954SCole Faust     // LSTM output state
496*c217d954SCole Faust     auto output_state = create_tensor<Tensor>(output_shape, DataType::QASYMM8, 1, qasymm);
497*c217d954SCole Faust 
498*c217d954SCole Faust     // LSTM cell state
499*c217d954SCole Faust     auto cell_state = create_tensor<Tensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
500*c217d954SCole Faust 
501*c217d954SCole Faust     NELSTMLayerQuantized lstmq;
502*c217d954SCole Faust 
503*c217d954SCole Faust     lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
504*c217d954SCole Faust                     &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
505*c217d954SCole Faust                     &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
506*c217d954SCole Faust 
507*c217d954SCole Faust     input.allocator()->allocate();
508*c217d954SCole Faust     input_to_input_weights.allocator()->allocate();
509*c217d954SCole Faust     input_to_forget_weights.allocator()->allocate();
510*c217d954SCole Faust     input_to_cell_weights.allocator()->allocate();
511*c217d954SCole Faust     input_to_output_weights.allocator()->allocate();
512*c217d954SCole Faust     recurrent_to_input_weights.allocator()->allocate();
513*c217d954SCole Faust     recurrent_to_forget_weights.allocator()->allocate();
514*c217d954SCole Faust     recurrent_to_cell_weights.allocator()->allocate();
515*c217d954SCole Faust     recurrent_to_output_weights.allocator()->allocate();
516*c217d954SCole Faust     input_gate_bias.allocator()->allocate();
517*c217d954SCole Faust     forget_gate_bias.allocator()->allocate();
518*c217d954SCole Faust     cell_gate_bias.allocator()->allocate();
519*c217d954SCole Faust     output_gate_bias.allocator()->allocate();
520*c217d954SCole Faust     cell_state.allocator()->allocate();
521*c217d954SCole Faust     output_state.allocator()->allocate();
522*c217d954SCole Faust 
523*c217d954SCole Faust     // Fill weights and biases
524*c217d954SCole Faust     fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 122,  130,
525*c217d954SCole Faust                                                               124,  134,
526*c217d954SCole Faust                                                                120,   122,
527*c217d954SCole Faust                                                              134,  134 });
528*c217d954SCole Faust 
529*c217d954SCole Faust     fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 204,  193,
530*c217d954SCole Faust                                                                 148,  59,
531*c217d954SCole Faust                                                                 113,  17,
532*c217d954SCole Faust                                                                  66, 197 });
533*c217d954SCole Faust 
534*c217d954SCole Faust     fill_tensor(input_to_cell_weights, std::vector<uint8_t> { 172,  101,
535*c217d954SCole Faust                                                               184, 209,
536*c217d954SCole Faust                                                               165,  82,
537*c217d954SCole Faust                                                               108, 209 });
538*c217d954SCole Faust 
539*c217d954SCole Faust     fill_tensor(input_to_output_weights, std::vector<uint8_t> { 203, 244,
540*c217d954SCole Faust                                                                 219, 114,
541*c217d954SCole Faust                                                                 130,  16,
542*c217d954SCole Faust                                                                 163, 222 });
543*c217d954SCole Faust 
544*c217d954SCole Faust     fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 162, 168,  7,  95,
545*c217d954SCole Faust                                                                     91, 155, 108, 216,
546*c217d954SCole Faust                                                                    255, 100,  48, 188,
547*c217d954SCole Faust                                                                     58,  37, 186, 147 });
548*c217d954SCole Faust 
549*c217d954SCole Faust     fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> {  46,  58,  47, 170,
550*c217d954SCole Faust                                                                     246,  96,  12,  99,
551*c217d954SCole Faust                                                                      68,  23, 186, 161,
552*c217d954SCole Faust                                                                     237, 164,  89,   6 });
553*c217d954SCole Faust 
554*c217d954SCole Faust     fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> { 234,  99,   71, 206,
555*c217d954SCole Faust                                                                   205, 159,   64, 253,
556*c217d954SCole Faust                                                                   191, 148,  116,   8,
557*c217d954SCole Faust                                                                   209, 136,   59, 138 });
558*c217d954SCole Faust 
559*c217d954SCole Faust     fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> {  23, 241, 137, 36,
560*c217d954SCole Faust                                                                     206,   5, 227, 56,
561*c217d954SCole Faust                                                                     254, 176, 231, 47,
562*c217d954SCole Faust                                                                      18, 201, 161, 11 });
563*c217d954SCole Faust 
564*c217d954SCole Faust     fill_tensor(input_gate_bias, std::vector<int>  {-103038,   30525,  115255, -38154 });
565*c217d954SCole Faust     fill_tensor(forget_gate_bias, std::vector<int> { -23428,  126970,  116806,  46307 });
566*c217d954SCole Faust     fill_tensor(cell_gate_bias, std::vector<int>   { 128006,   69949,  -42808,  42568 });
567*c217d954SCole Faust     fill_tensor(output_gate_bias, std::vector<int> { -67066,  -53607,   47233,  7300  });
568*c217d954SCole Faust 
569*c217d954SCole Faust     SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
570*c217d954SCole Faust 
571*c217d954SCole Faust     // Initialize state
572*c217d954SCole Faust     fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128,
573*c217d954SCole Faust                                                      128, 128, 128, 128 });
574*c217d954SCole Faust     fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0,
575*c217d954SCole Faust                                                    0, 0, 0, 0 });
576*c217d954SCole Faust 
577*c217d954SCole Faust     // First input
578*c217d954SCole Faust     fill_tensor(input, std::vector<uint8_t> { 106,  193,
579*c217d954SCole Faust                                               155,  150 });
580*c217d954SCole Faust 
581*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 128, 128,  31, 128,
582*c217d954SCole Faust                                                         128, 128,  31, 128 });
583*c217d954SCole Faust 
584*c217d954SCole Faust     lstmq.run();
585*c217d954SCole Faust     validate(Accessor(output_state), expected_output);
586*c217d954SCole Faust 
587*c217d954SCole Faust     // Second input
588*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 5, 128,
589*c217d954SCole Faust                                                         128, 128, 5, 128 });
590*c217d954SCole Faust     lstmq.run();
591*c217d954SCole Faust     validate(Accessor(output_state), expected_output);
592*c217d954SCole Faust 
593*c217d954SCole Faust     // Third input
594*c217d954SCole Faust     fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 1, 128,
595*c217d954SCole Faust                                                         128, 128, 1, 128, });
596*c217d954SCole Faust     lstmq.run();
597*c217d954SCole Faust     validate(Accessor(output_state), expected_output);
598*c217d954SCole Faust }
599*c217d954SCole Faust TEST_SUITE_END() // MultGreater1
600*c217d954SCole Faust TEST_SUITE_END() // IntegrationTestCase
601*c217d954SCole Faust // clang-format on
602*c217d954SCole Faust // *INDENT-ON*
603*c217d954SCole Faust 
604*c217d954SCole Faust TEST_SUITE_END() // LSTMLayerQuantized
605*c217d954SCole Faust TEST_SUITE_END() // Neon
606*c217d954SCole Faust } // namespace validation
607*c217d954SCole Faust } // namespace test
608*c217d954SCole Faust } // namespace arm_compute
609