1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "UnidirectionalSequenceLstmTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmTest(std::vector<armnn::BackendId> & backends)17*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmTest(std::vector<armnn::BackendId>& backends)
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
20*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
21*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
22*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
23*89c4ff92SAndroid Build Coastguard Worker // cellSize and outputSize have the same size when there is no projection.
24*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = outputSize;
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker //tensorInfo12,
27*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
28*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToInputWeights = { -0.49536117f, -0.0556083915f, -0.102400711f,
29*89c4ff92SAndroid Build Coastguard Worker -0.117484632f, 0.3298470976f, -0.1179017122f,
30*89c4ff92SAndroid Build Coastguard Worker 0.214305695f, 0.42135173085f, 0.003878414626f,
31*89c4ff92SAndroid Build Coastguard Worker -0.348303917f, -0.1881275477f, 0.0343011027f };
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToForgetWeights = { 0.2415594226f, 0.15400093799f, 0.4566498398f,
34*89c4ff92SAndroid Build Coastguard Worker -0.3810434485f, 0.268383264f, -0.009807467424f,
35*89c4ff92SAndroid Build Coastguard Worker -0.3522925403f, -0.24275735512f, -0.28344226125f,
36*89c4ff92SAndroid Build Coastguard Worker 0.13512269116f, -0.4932442977f, -0.10039821991f };
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToCellWeights = { -0.2504855627f, 0.184490025045f, -0.2480507493f,
39*89c4ff92SAndroid Build Coastguard Worker 0.386399507f, -0.259465157985f, -0.16545993089f,
40*89c4ff92SAndroid Build Coastguard Worker -0.4230232555f, 0.341664791103f, -0.18127849691f,
41*89c4ff92SAndroid Build Coastguard Worker -0.2277662414f, -0.55275535589f, 0.34184026718f };
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToOutputWeights = { 0.2303854227f, 0.5218806862f, -0.4865379333f,
44*89c4ff92SAndroid Build Coastguard Worker 0.53969591851f, 0.23393625035f, -0.27140527306f,
45*89c4ff92SAndroid Build Coastguard Worker 0.50009280443f, 0.07511717046f, 0.3998299249f,
46*89c4ff92SAndroid Build Coastguard Worker -0.51717478049f, 0.1889653282f, -0.367323637f };
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16,
49*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
50*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToInputWeights = { -0.128009796112f, 0.1995525098f, -0.07745539397f, 0.1558421701f,
51*89c4ff92SAndroid Build Coastguard Worker -0.265254765766f, -0.38837709614f, -0.05636804124f, 0.4259087456f,
52*89c4ff92SAndroid Build Coastguard Worker 0.17628988623f, 0.3877420127f, 0.53300309181f, -0.0959980934f,
53*89c4ff92SAndroid Build Coastguard Worker 0.00302857416f, 0.3266998827f, -0.142509296562f, -0.04433270756f };
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToForgetWeights = { -0.09499983487f, -0.08814888417f, -0.04834804721f, 0.1516668247f,
56*89c4ff92SAndroid Build Coastguard Worker -0.3967529535f, -0.06463699788f, 0.4952811002f, 0.003274492938f,
57*89c4ff92SAndroid Build Coastguard Worker -0.0968840941f, 0.17928104102f, 0.0031281141592f, -0.3387276584f,
58*89c4ff92SAndroid Build Coastguard Worker -0.3587934076f, 0.06705895066f, 0.22463923692f, 0.1961955726f };
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToCellWeights = { -0.21938985582f, -0.3023648226f, -0.1170005202f, -0.3509177422f,
61*89c4ff92SAndroid Build Coastguard Worker -0.4286288613f, 0.2726137042f, 0.09216640889f, -0.06551410215f,
62*89c4ff92SAndroid Build Coastguard Worker 0.20453298098f, 0.2393476665f, 0.11846517771f, 0.2630801796f,
63*89c4ff92SAndroid Build Coastguard Worker 0.3954237699f, -0.19407111404f, 0.30412107706f, -0.27342408554f };
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToOutputWeights = { -0.32921677827f, 0.32624614238f, -0.1388191282f, -0.17879831790f,
66*89c4ff92SAndroid Build Coastguard Worker -0.15185534954f, -0.16918526583f, -0.10087361183f, -0.5436913968f,
67*89c4ff92SAndroid Build Coastguard Worker 0.016758225858f, 0.30454617738f, -0.41493862867f, -0.005565764375f,
68*89c4ff92SAndroid Build Coastguard Worker -0.12584099173f, -0.12319286912f, 0.2407919466f, -0.08879069983f };
69*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
70*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = false;
71*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToInputWeights;
72*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = false;
73*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToForgetWeights;
74*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = false;
75*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToOutputWeights;
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
78*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = {0., 0., 0., 0.};
79*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = {1., 1., 1., 1.};
80*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = {0., 0., 0., 0.};
81*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = {0., 0., 0., 0.};
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = false;
84*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionWeights;
85*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = false;
86*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias;
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
89*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
90*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
91*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
92*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
93*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
94*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
95*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1., 2., 3., 4., 5., 4.,
98*89c4ff92SAndroid Build Coastguard Worker 3., 2., 1., 2., 3., 4.,
99*89c4ff92SAndroid Build Coastguard Worker 5., 4., 3., 2., 1., 2. };
100*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { -0.0714901f, -0.162117f, -0.175168f, -0.0232934f,
101*89c4ff92SAndroid Build Coastguard Worker -0.168107f, -0.414129f, -0.549875f, -0.00803579f,
102*89c4ff92SAndroid Build Coastguard Worker -0.0668735f, 0.204078f, -0.42765f, -0.0312321f,
103*89c4ff92SAndroid Build Coastguard Worker -0.120003f, -0.0941918f, -0.456391f, -0.0287019f,
104*89c4ff92SAndroid Build Coastguard Worker -0.0342921f, 0.20824f, -0.656989f, -0.00415265f,
105*89c4ff92SAndroid Build Coastguard Worker -0.10493f, 0.14211f, -0.583478f, -0.0329754f };
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
108*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
109*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
110*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<float>(backends,
113*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
114*89c4ff92SAndroid Build Coastguard Worker batchSize,
115*89c4ff92SAndroid Build Coastguard Worker timeSize,
116*89c4ff92SAndroid Build Coastguard Worker inputSize,
117*89c4ff92SAndroid Build Coastguard Worker outputSize,
118*89c4ff92SAndroid Build Coastguard Worker numUnits,
119*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
120*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
121*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
122*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
123*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
124*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
125*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
126*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
127*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
128*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
129*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
130*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
131*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
132*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
133*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
134*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
135*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
136*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
137*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
138*89c4ff92SAndroid Build Coastguard Worker cellBias,
139*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
140*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
141*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
142*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
143*89c4ff92SAndroid Build Coastguard Worker projectionBias,
144*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
145*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
146*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
147*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
148*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
149*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
150*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
151*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
152*89c4ff92SAndroid Build Coastguard Worker inputValues,
153*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
154*89c4ff92SAndroid Build Coastguard Worker activationFunction,
155*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
156*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
157*89c4ff92SAndroid Build Coastguard Worker isTimeMajor);
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmTimeMajorTest(std::vector<armnn::BackendId> & backends)160*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmTimeMajorTest(std::vector<armnn::BackendId>& backends)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
163*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
164*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
165*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
166*89c4ff92SAndroid Build Coastguard Worker // cellSize and outputSize have the same size when there is no projection.
167*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = outputSize;
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape = {timeSize, batchSize, inputSize};
170*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> cellStateInTensorInfo = {batchSize, numUnits};
171*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputStateInTensorInfo = {batchSize, outputSize};
172*89c4ff92SAndroid Build Coastguard Worker
173*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputTensorInfo = {timeSize, batchSize, outputSize};
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker //tensorInfo12
176*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
177*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToInputWeights = { 0.27277296781539917f, 0.3813590407371521f, -0.394489049911499f,
178*89c4ff92SAndroid Build Coastguard Worker 0.2782636880874634f, -0.3793870210647583f, -0.018918335437774658f,
179*89c4ff92SAndroid Build Coastguard Worker 0.2724653482437134f, -0.19314253330230713f, -0.2947450876235962f,
180*89c4ff92SAndroid Build Coastguard Worker -0.30253493785858154f, 0.4241350293159485f, -0.22560018301010132f };
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToForgetWeights = { -0.2667974531650543f, -0.05505800247192383f, -0.20932340621948242f,
183*89c4ff92SAndroid Build Coastguard Worker -0.14345619082450867f, 0.09666192531585693f, -0.2604355812072754f,
184*89c4ff92SAndroid Build Coastguard Worker -0.2681812047958374f, -0.3314584493637085f, 0.4485899806022644f,
185*89c4ff92SAndroid Build Coastguard Worker -0.23467743396759033f, 0.5072842240333557f, -0.4192768931388855f };
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToCellWeights = { -0.15782442688941956f, -0.027530014514923096f, 0.4789854884147644f,
188*89c4ff92SAndroid Build Coastguard Worker 0.23227906227111816f, 0.28259342908859253f, -0.030095696449279785f,
189*89c4ff92SAndroid Build Coastguard Worker 0.10071521997451782f, -0.08535495400428772f, 0.18563997745513916f,
190*89c4ff92SAndroid Build Coastguard Worker -0.3049069046974182f, -0.478048175573349f, 0.025234103202819824f };
191*89c4ff92SAndroid Build Coastguard Worker
192*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToOutputWeights = { -0.04584759473800659f, -0.2716066539287567f, 0.012970447540283203f,
193*89c4ff92SAndroid Build Coastguard Worker -0.4729190170764923f, -0.37422770261764526f, 0.49352723360061646f,
194*89c4ff92SAndroid Build Coastguard Worker 0.3163864016532898f, -0.436781644821167f, -0.33074596524238586f,
195*89c4ff92SAndroid Build Coastguard Worker -0.32885751128196716f, -0.40959352254867554f, -0.2124689817428589f };
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16
198*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
199*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToInputWeights = { 0.23788475990f, -0.24948765337f, 0.50044941902f, 0.14431896805f,
200*89c4ff92SAndroid Build Coastguard Worker -0.115940228137f, -0.717082679f, -0.17208620906f, 0.17850610617f,
201*89c4ff92SAndroid Build Coastguard Worker -0.16702319684f, -0.11384502053f, -0.309785276245f, -0.3316611672f,
202*89c4ff92SAndroid Build Coastguard Worker 0.52380162477f, -0.06839632987f, -0.391478359627f, -0.10756178963f };
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToForgetWeights = { 0.11383482068f, 0.1676601767f, -0.08550968004f, 0.03399394089f,
205*89c4ff92SAndroid Build Coastguard Worker 0.08042152225f, -0.2133381964f, 0.05182432704f, 0.38161808255f,
206*89c4ff92SAndroid Build Coastguard Worker -0.5018365979f, -0.08043262364f, 0.07894329014f, -0.07547105155f,
207*89c4ff92SAndroid Build Coastguard Worker 0.12047368288f, 0.2986997961f, 0.0485043078f, -0.13372567296f };
208*89c4ff92SAndroid Build Coastguard Worker
209*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToCellWeights = { 0.0433832928545f, 0.07587072294f, -0.120520234107f, 0.604576051f,
210*89c4ff92SAndroid Build Coastguard Worker -0.434353142986f, 0.009314475068f, 0.005085289478f, 0.08488202038f,
211*89c4ff92SAndroid Build Coastguard Worker -0.00025437487886f, 0.15245915082f, -0.1936587542f, 0.004754020f,
212*89c4ff92SAndroid Build Coastguard Worker -0.1582719236f, 0.3307867646f, 0.0236605107784f, 0.307716339826f };
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToOutputWeights = { -0.079031050201f, 0.041414566286f, -0.583727357285f, 0.1025384515f,
215*89c4ff92SAndroid Build Coastguard Worker -0.172372072937f, 0.09214124082f, 0.178184121827f, -0.2439443916f,
216*89c4ff92SAndroid Build Coastguard Worker 0.104485116899f, 0.2600405514f, 0.064414866268f, 0.24141204357f,
217*89c4ff92SAndroid Build Coastguard Worker 0.281875759363f, -0.14234502664f, 0.15126448862f, -0.24421440064f };
218*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
219*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = false;
220*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToInputWeights;
221*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = false;
222*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToForgetWeights;
223*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = false;
224*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToOutputWeights;
225*89c4ff92SAndroid Build Coastguard Worker
226*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
227*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = {0., 0., 0., 0.};
228*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = {1., 1., 1., 1.};
229*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = {0., 0., 0., 0.};
230*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = {0., 0., 0., 0.};
231*89c4ff92SAndroid Build Coastguard Worker
232*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = false;
233*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionWeights;
234*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = false;
235*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias;
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
238*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
239*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
240*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
241*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
242*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
243*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
244*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
245*89c4ff92SAndroid Build Coastguard Worker
246*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1., 2., 3., 4., 5., 4.,
247*89c4ff92SAndroid Build Coastguard Worker 3., 2., 1., 2., 3., 4.,
248*89c4ff92SAndroid Build Coastguard Worker 5., 4., 3., 2., 1., 2. };
249*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 0.135658f, 0.124673f, 0.021209f, -0.0530204f,
250*89c4ff92SAndroid Build Coastguard Worker 0.106138f, 0.0404792f, 0.0151644f, -0.00675166f,
251*89c4ff92SAndroid Build Coastguard Worker -0.0128514f, 0.0644884f, 0.0709072f, -0.0454045f,
252*89c4ff92SAndroid Build Coastguard Worker 0.162886f, 0.166494f, 0.0277046f, -0.0369807f,
253*89c4ff92SAndroid Build Coastguard Worker 0.111716f, 0.043119f, 0.0762981f, -0.0122854f,
254*89c4ff92SAndroid Build Coastguard Worker 0.104397f, 0.2144f, 0.119192f, -0.0839058f };
255*89c4ff92SAndroid Build Coastguard Worker
256*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
257*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
258*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
259*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = true;
260*89c4ff92SAndroid Build Coastguard Worker
261*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<float>(backends,
262*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
263*89c4ff92SAndroid Build Coastguard Worker batchSize,
264*89c4ff92SAndroid Build Coastguard Worker timeSize,
265*89c4ff92SAndroid Build Coastguard Worker inputSize,
266*89c4ff92SAndroid Build Coastguard Worker outputSize,
267*89c4ff92SAndroid Build Coastguard Worker numUnits,
268*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
269*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
270*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
271*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
272*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
273*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
274*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
275*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
276*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
277*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
278*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
279*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
280*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
281*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
282*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
283*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
284*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
285*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
286*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
287*89c4ff92SAndroid Build Coastguard Worker cellBias,
288*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
289*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
290*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
291*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
292*89c4ff92SAndroid Build Coastguard Worker projectionBias,
293*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
294*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
295*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
296*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
297*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
298*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
299*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
300*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
301*89c4ff92SAndroid Build Coastguard Worker inputValues,
302*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
303*89c4ff92SAndroid Build Coastguard Worker activationFunction,
304*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
305*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
306*89c4ff92SAndroid Build Coastguard Worker isTimeMajor);
307*89c4ff92SAndroid Build Coastguard Worker }
308*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionTest(std::vector<armnn::BackendId> & backends)309*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionTest(std::vector<armnn::BackendId>& backends)
310*89c4ff92SAndroid Build Coastguard Worker {
311*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 2;
312*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 3;
313*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 4;
314*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 5;
315*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = 6;
316*89c4ff92SAndroid Build Coastguard Worker
317*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape = {batchSize, timeSize, inputSize};
318*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> cellStateInTensorInfo = {batchSize, numUnits};
319*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputStateInTensorInfo = {batchSize, outputSize};
320*89c4ff92SAndroid Build Coastguard Worker
321*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputTensorInfo = {batchSize, timeSize, outputSize};
322*89c4ff92SAndroid Build Coastguard Worker
323*89c4ff92SAndroid Build Coastguard Worker //tensorInfoInputSize,
324*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
325*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToInputWeights = { 0.021393683f, 0.06124551f, 0.046905167f, -0.014657677f,
326*89c4ff92SAndroid Build Coastguard Worker -0.03149463f, 0.09171803f, 0.14647801f, 0.10797193f,
327*89c4ff92SAndroid Build Coastguard Worker -0.0057968358f, 0.0019193048f, -0.2726754f, 0.10154029f,
328*89c4ff92SAndroid Build Coastguard Worker -0.018539885f, 0.080349885f, -0.10262385f, -0.022599787f,
329*89c4ff92SAndroid Build Coastguard Worker -0.09121155f, -0.008675967f, -0.045206103f, -0.0821282f,
330*89c4ff92SAndroid Build Coastguard Worker -0.008045952f, 0.015478081f, 0.055217247f, 0.038719587f };
331*89c4ff92SAndroid Build Coastguard Worker
332*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToForgetWeights = { -0.0018401089f, -0.004852237f, 0.03698424f, 0.014181704f,
333*89c4ff92SAndroid Build Coastguard Worker 0.028273236f, -0.016726194f, -0.05249759f, -0.10204261f,
334*89c4ff92SAndroid Build Coastguard Worker 0.00861066f, -0.040979505f, -0.009899187f, 0.01923892f,
335*89c4ff92SAndroid Build Coastguard Worker -0.028177269f, -0.08535103f, -0.14585495f, 0.10662567f,
336*89c4ff92SAndroid Build Coastguard Worker -0.01909731f, -0.017883534f, -0.0047269356f, -0.045103323f,
337*89c4ff92SAndroid Build Coastguard Worker 0.0030784295f, 0.076784775f, 0.07463696f, 0.094531395f};
338*89c4ff92SAndroid Build Coastguard Worker
339*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f,
340*89c4ff92SAndroid Build Coastguard Worker -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f,
341*89c4ff92SAndroid Build Coastguard Worker -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f,
342*89c4ff92SAndroid Build Coastguard Worker -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f,
343*89c4ff92SAndroid Build Coastguard Worker -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f,
344*89c4ff92SAndroid Build Coastguard Worker 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f };
345*89c4ff92SAndroid Build Coastguard Worker
346*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToOutputWeights = { -0.0998932f, -0.07201956f, -0.052803773f, -0.15629593f,
347*89c4ff92SAndroid Build Coastguard Worker -0.15001918f, -0.07650751f, 0.02359855f, -0.075155355f,
348*89c4ff92SAndroid Build Coastguard Worker -0.08037709f, -0.15093534f, 0.029517552f, -0.04751393f,
349*89c4ff92SAndroid Build Coastguard Worker 0.010350531f, -0.02664851f, -0.016839722f, -0.023121163f,
350*89c4ff92SAndroid Build Coastguard Worker 0.0077019283f, 0.012851257f, -0.05040649f, -0.0129761f,
351*89c4ff92SAndroid Build Coastguard Worker -0.021737747f, -0.038305793f, -0.06870586f, -0.01481247f };
352*89c4ff92SAndroid Build Coastguard Worker
353*89c4ff92SAndroid Build Coastguard Worker //tensorInfoOutputSize,
354*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
355*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f,
356*89c4ff92SAndroid Build Coastguard Worker -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f,
357*89c4ff92SAndroid Build Coastguard Worker -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f,
358*89c4ff92SAndroid Build Coastguard Worker -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f,
359*89c4ff92SAndroid Build Coastguard Worker 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f,
360*89c4ff92SAndroid Build Coastguard Worker 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f,
361*89c4ff92SAndroid Build Coastguard Worker -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f,
362*89c4ff92SAndroid Build Coastguard Worker 0.14283475f, -0.07390571f };
363*89c4ff92SAndroid Build Coastguard Worker
364*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToForgetWeights = { -0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f,
365*89c4ff92SAndroid Build Coastguard Worker 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f,
366*89c4ff92SAndroid Build Coastguard Worker -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f,
367*89c4ff92SAndroid Build Coastguard Worker 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f,
368*89c4ff92SAndroid Build Coastguard Worker 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f,
369*89c4ff92SAndroid Build Coastguard Worker -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f,
370*89c4ff92SAndroid Build Coastguard Worker -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f,
371*89c4ff92SAndroid Build Coastguard Worker 0.061878487f, -0.04729229f };
372*89c4ff92SAndroid Build Coastguard Worker
373*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f,
374*89c4ff92SAndroid Build Coastguard Worker 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f,
375*89c4ff92SAndroid Build Coastguard Worker 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f,
376*89c4ff92SAndroid Build Coastguard Worker -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f,
377*89c4ff92SAndroid Build Coastguard Worker 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f,
378*89c4ff92SAndroid Build Coastguard Worker 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f,
379*89c4ff92SAndroid Build Coastguard Worker -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f,
380*89c4ff92SAndroid Build Coastguard Worker -0.019443132f, -0.030755889f };
381*89c4ff92SAndroid Build Coastguard Worker
382*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,
383*89c4ff92SAndroid Build Coastguard Worker -0.045984812f,-0.01255415f, -0.0026479573f,
384*89c4ff92SAndroid Build Coastguard Worker -0.08196161f, -0.054914974f, -0.0046604523f,
385*89c4ff92SAndroid Build Coastguard Worker -0.029587349f, -0.044576716f, -0.07480124f,
386*89c4ff92SAndroid Build Coastguard Worker -0.082868785f, 0.023254942f, 0.027502948f,
387*89c4ff92SAndroid Build Coastguard Worker -0.0039728214f, -0.08683098f, -0.08116779f,
388*89c4ff92SAndroid Build Coastguard Worker -0.014675607f, -0.037924774f, -0.023314456f,
389*89c4ff92SAndroid Build Coastguard Worker -0.007401714f, -0.09255757f, 0.029460307f,
390*89c4ff92SAndroid Build Coastguard Worker -0.08829125f, -0.005139627f, -0.08989442f,
391*89c4ff92SAndroid Build Coastguard Worker -0.0555066f, 0.13596267f, 0.025062224f };
392*89c4ff92SAndroid Build Coastguard Worker // tensorInfoNumUnits
393*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = true;
394*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToInputWeights = { 0.040369894f, 0.030746894f, 0.24704495f,
395*89c4ff92SAndroid Build Coastguard Worker 0.018586371f, -0.037586458f, -0.15312155f };
396*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = true;
397*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToForgetWeights = { -0.01998659f, -0.15568835f, -0.24248174f,
398*89c4ff92SAndroid Build Coastguard Worker -0.012770197f, 0.041331276f, -0.072311886f };
399*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = true;
400*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f,
401*89c4ff92SAndroid Build Coastguard Worker 0.002913762f, 0.17764764f, -0.5495371f };
402*89c4ff92SAndroid Build Coastguard Worker
403*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
404*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = { 0.02234832f, 0.14757581f, 0.18176508f,
405*89c4ff92SAndroid Build Coastguard Worker 0.10380666f, 0.053110216f, -0.06928846f };
406*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 0.035185695f, -0.042891346f, -0.03032477f,
407*89c4ff92SAndroid Build Coastguard Worker 0.23027696f, 0.11098921f, 0.08989442f };
408*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f,
409*89c4ff92SAndroid Build Coastguard Worker 0.033463873f, -0.1483596f, 0.029460307f };
410*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0.046159424f, -0.0012809046f, 0.03563469f,
411*89c4ff92SAndroid Build Coastguard Worker 0.12648113f, 0.027195795f, 0.35373217f };
412*89c4ff92SAndroid Build Coastguard Worker
413*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = true;
414*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionWeights = { -0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f,
415*89c4ff92SAndroid Build Coastguard Worker 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f,
416*89c4ff92SAndroid Build Coastguard Worker -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f,
417*89c4ff92SAndroid Build Coastguard Worker -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f,
418*89c4ff92SAndroid Build Coastguard Worker 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f,
419*89c4ff92SAndroid Build Coastguard Worker 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f };
420*89c4ff92SAndroid Build Coastguard Worker
421*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = true;
422*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias(outputSize, 0.f);
423*89c4ff92SAndroid Build Coastguard Worker
424*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
425*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
426*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
427*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
428*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
429*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
430*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
431*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
432*89c4ff92SAndroid Build Coastguard Worker
433*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1., 2., 3., 4., 5., 4.,
434*89c4ff92SAndroid Build Coastguard Worker 3., 2., 1., 2., 3., 4.,
435*89c4ff92SAndroid Build Coastguard Worker 5., 4., 3., 2., 1., 2.,
436*89c4ff92SAndroid Build Coastguard Worker 1., 2., 3., 4., 5., 4.};
437*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { -0.0135612f, -0.0263441f, 0.0314008f, -0.00883455f, 0.00763052f,
438*89c4ff92SAndroid Build Coastguard Worker -0.00126877f, -0.0292959f, 0.0449957f, -0.00976195f, -0.00492338f,
439*89c4ff92SAndroid Build Coastguard Worker -0.0175702f, -0.0431753f, 0.0597117f, -0.0169154f, 0.0142087f,
440*89c4ff92SAndroid Build Coastguard Worker 0.00472515f, -0.0196355f, 0.0342524f, -0.00407936f, -0.0253189f,
441*89c4ff92SAndroid Build Coastguard Worker -0.00512944f, -0.0293754f, 0.0512771f, -0.0151874f, -0.0246433f,
442*89c4ff92SAndroid Build Coastguard Worker -0.00744986f, -0.0345103f, 0.0450666f, -0.00944991f, 0.0126895f };
443*89c4ff92SAndroid Build Coastguard Worker
444*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
445*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
446*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
447*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
448*89c4ff92SAndroid Build Coastguard Worker
449*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<float>(backends,
450*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
451*89c4ff92SAndroid Build Coastguard Worker batchSize,
452*89c4ff92SAndroid Build Coastguard Worker timeSize,
453*89c4ff92SAndroid Build Coastguard Worker inputSize,
454*89c4ff92SAndroid Build Coastguard Worker outputSize,
455*89c4ff92SAndroid Build Coastguard Worker numUnits,
456*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
457*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
458*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
459*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
460*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
461*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
462*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
463*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
464*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
465*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
466*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
467*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
468*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
469*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
470*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
471*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
472*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
473*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
474*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
475*89c4ff92SAndroid Build Coastguard Worker cellBias,
476*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
477*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
478*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
479*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
480*89c4ff92SAndroid Build Coastguard Worker projectionBias,
481*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
482*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
483*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
484*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
485*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
486*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
487*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
488*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
489*89c4ff92SAndroid Build Coastguard Worker inputValues,
490*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
491*89c4ff92SAndroid Build Coastguard Worker activationFunction,
492*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
493*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
494*89c4ff92SAndroid Build Coastguard Worker isTimeMajor);
495*89c4ff92SAndroid Build Coastguard Worker }
496*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest(std::vector<armnn::BackendId> & backends)497*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest(std::vector<armnn::BackendId>& backends)
498*89c4ff92SAndroid Build Coastguard Worker {
499*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
500*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
501*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
502*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
503*89c4ff92SAndroid Build Coastguard Worker // cellSize and outputSize have the same size when there is no projection.
504*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = outputSize;
505*89c4ff92SAndroid Build Coastguard Worker
506*89c4ff92SAndroid Build Coastguard Worker //tensorInfo12
507*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = false;
508*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToInputWeights{};
509*89c4ff92SAndroid Build Coastguard Worker
510*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToForgetWeights = { 0.2415594226f, 0.15400093799f, 0.4566498398f,
511*89c4ff92SAndroid Build Coastguard Worker -0.3810434485f, 0.268383264f, -0.009807467424f,
512*89c4ff92SAndroid Build Coastguard Worker -0.3522925403f, -0.24275735512f, -0.28344226125f,
513*89c4ff92SAndroid Build Coastguard Worker 0.13512269116f, -0.4932442977f, -0.10039821991f };
514*89c4ff92SAndroid Build Coastguard Worker
515*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToCellWeights = { -0.2504855627f, 0.184490025045f, -0.2480507493f,
516*89c4ff92SAndroid Build Coastguard Worker 0.386399507f, -0.259465157985f, -0.16545993089f,
517*89c4ff92SAndroid Build Coastguard Worker -0.4230232555f, 0.341664791103f, -0.18127849691f,
518*89c4ff92SAndroid Build Coastguard Worker -0.2277662414f, -0.55275535589f, 0.34184026718f };
519*89c4ff92SAndroid Build Coastguard Worker
520*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToOutputWeights = { 0.2303854227f, 0.5218806862f, -0.4865379333f,
521*89c4ff92SAndroid Build Coastguard Worker 0.53969591851f, 0.23393625035f, -0.27140527306f,
522*89c4ff92SAndroid Build Coastguard Worker 0.50009280443f, 0.07511717046f, 0.3998299249f,
523*89c4ff92SAndroid Build Coastguard Worker -0.51717478049f, 0.1889653282f, -0.367323637f };
524*89c4ff92SAndroid Build Coastguard Worker
525*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16
526*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = false;
527*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToInputWeights{};
528*89c4ff92SAndroid Build Coastguard Worker
529*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToForgetWeights = { -0.09499983487f, -0.08814888417f, -0.04834804721f, 0.1516668247f,
530*89c4ff92SAndroid Build Coastguard Worker -0.3967529535f, -0.06463699788f, 0.4952811002f, 0.003274492938f,
531*89c4ff92SAndroid Build Coastguard Worker -0.0968840941f, 0.17928104102f, 0.0031281141592f, -0.3387276584f,
532*89c4ff92SAndroid Build Coastguard Worker -0.3587934076f, 0.06705895066f, 0.22463923692f, 0.1961955726f };
533*89c4ff92SAndroid Build Coastguard Worker
534*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToCellWeights = { -0.21938985582f, -0.3023648226f, -0.1170005202f, -0.3509177422f,
535*89c4ff92SAndroid Build Coastguard Worker -0.4286288613f, 0.2726137042f, 0.09216640889f, -0.06551410215f,
536*89c4ff92SAndroid Build Coastguard Worker 0.20453298098f, 0.2393476665f, 0.11846517771f, 0.2630801796f,
537*89c4ff92SAndroid Build Coastguard Worker 0.3954237699f, -0.19407111404f, 0.30412107706f, -0.27342408554f };
538*89c4ff92SAndroid Build Coastguard Worker
539*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToOutputWeights = { -0.32921677827f, 0.32624614238f, -0.1388191282f, -0.17879831790f,
540*89c4ff92SAndroid Build Coastguard Worker -0.15185534954f, -0.16918526583f, -0.10087361183f, -0.5436913968f,
541*89c4ff92SAndroid Build Coastguard Worker 0.016758225858f, 0.30454617738f, -0.41493862867f, -0.005565764375f,
542*89c4ff92SAndroid Build Coastguard Worker -0.12584099173f, -0.12319286912f, 0.2407919466f, -0.08879069983f };
543*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
544*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = false;
545*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToInputWeights;
546*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = true;
547*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
548*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = true;
549*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};
550*89c4ff92SAndroid Build Coastguard Worker
551*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = false;
552*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias;
553*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = {1., 1., 1., 1.};
554*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = {0., 0., 0., 0.};
555*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = {0., 0., 0., 0.};
556*89c4ff92SAndroid Build Coastguard Worker
557*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = false;
558*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionWeights;
559*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = false;
560*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias;
561*89c4ff92SAndroid Build Coastguard Worker
562*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
563*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
564*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
565*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
566*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
567*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
568*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
569*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
570*89c4ff92SAndroid Build Coastguard Worker
571*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1., 2., 3., 4., 5., 4.,
572*89c4ff92SAndroid Build Coastguard Worker 3., 2., 1., 2., 3., 4.,
573*89c4ff92SAndroid Build Coastguard Worker 5., 4., 3., 2., 1., 2. };
574*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { -0.0129257f, -0.070531f, -0.153508f, -0.0392391f,
575*89c4ff92SAndroid Build Coastguard Worker -0.0300169f, -0.195717f, -0.528679f, -0.0818106f,
576*89c4ff92SAndroid Build Coastguard Worker -0.0332748f, 0.155429f, -0.353966f, -0.0801505f,
577*89c4ff92SAndroid Build Coastguard Worker -0.032312f, -0.0407911f, -0.435053f, -0.0932317f,
578*89c4ff92SAndroid Build Coastguard Worker -0.0108233f, 0.165584f, -0.640424f, -0.0447535f,
579*89c4ff92SAndroid Build Coastguard Worker -0.031675f, 0.125987f, -0.526695f, -0.110093f };
580*89c4ff92SAndroid Build Coastguard Worker
581*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
582*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
583*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
584*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
585*89c4ff92SAndroid Build Coastguard Worker
586*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<float>(backends,
587*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
588*89c4ff92SAndroid Build Coastguard Worker batchSize,
589*89c4ff92SAndroid Build Coastguard Worker timeSize,
590*89c4ff92SAndroid Build Coastguard Worker inputSize,
591*89c4ff92SAndroid Build Coastguard Worker outputSize,
592*89c4ff92SAndroid Build Coastguard Worker numUnits,
593*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
594*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
595*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
596*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
597*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
598*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
599*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
600*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
601*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
602*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
603*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
604*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
605*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
606*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
607*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
608*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
609*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
610*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
611*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
612*89c4ff92SAndroid Build Coastguard Worker cellBias,
613*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
614*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
615*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
616*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
617*89c4ff92SAndroid Build Coastguard Worker projectionBias,
618*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
619*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
620*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
621*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
622*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
623*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
624*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
625*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
626*89c4ff92SAndroid Build Coastguard Worker inputValues,
627*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
628*89c4ff92SAndroid Build Coastguard Worker activationFunction,
629*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
630*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
631*89c4ff92SAndroid Build Coastguard Worker isTimeMajor);
632*89c4ff92SAndroid Build Coastguard Worker }
633*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNormTest(std::vector<armnn::BackendId> & backends)634*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNormTest(
635*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId>& backends)
636*89c4ff92SAndroid Build Coastguard Worker {
637*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
638*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
639*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
640*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
641*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = 5;
642*89c4ff92SAndroid Build Coastguard Worker
643*89c4ff92SAndroid Build Coastguard Worker //tensorInfo15
644*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
645*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToInputWeights = { -0.49536117f, -0.0556083915f, -0.102400711f,
646*89c4ff92SAndroid Build Coastguard Worker -0.117484632f, 0.3298470976f, -0.1179017122f,
647*89c4ff92SAndroid Build Coastguard Worker 0.214305695f, 0.42135173085f, 0.003878414626f,
648*89c4ff92SAndroid Build Coastguard Worker -0.348303917f, -0.1881275477f, 0.0343011027f,
649*89c4ff92SAndroid Build Coastguard Worker -0.38837709614f, -0.05636804124f, 0.4259087456f};
650*89c4ff92SAndroid Build Coastguard Worker
651*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToForgetWeights = { 0.2415594226f, 0.15400093799f, 0.4566498398f,
652*89c4ff92SAndroid Build Coastguard Worker -0.3810434485f, 0.268383264f, -0.009807467424f,
653*89c4ff92SAndroid Build Coastguard Worker -0.3522925403f, -0.24275735512f, -0.28344226125f,
654*89c4ff92SAndroid Build Coastguard Worker 0.13512269116f, -0.4932442977f, -0.10039821991f,
655*89c4ff92SAndroid Build Coastguard Worker 0.2726137042f, 0.09216640889f, -0.06551410215f};
656*89c4ff92SAndroid Build Coastguard Worker
657*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToCellWeights = { -0.2504855627f, 0.184490025045f, -0.2480507493f,
658*89c4ff92SAndroid Build Coastguard Worker 0.386399507f, -0.259465157985f, -0.16545993089f,
659*89c4ff92SAndroid Build Coastguard Worker -0.4230232555f, 0.341664791103f, -0.18127849691f,
660*89c4ff92SAndroid Build Coastguard Worker -0.2277662414f, -0.55275535589f, 0.34184026718f,
661*89c4ff92SAndroid Build Coastguard Worker 0.3954237699f, -0.19407111404f, 0.30412107706f};
662*89c4ff92SAndroid Build Coastguard Worker
663*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputToOutputWeights = { 0.2303854227f, 0.5218806862f, -0.4865379333f,
664*89c4ff92SAndroid Build Coastguard Worker 0.53969591851f, 0.23393625035f, -0.27140527306f,
665*89c4ff92SAndroid Build Coastguard Worker 0.50009280443f, 0.07511717046f, 0.3998299249f,
666*89c4ff92SAndroid Build Coastguard Worker -0.51717478049f, 0.1889653282f, -0.367323637f,
667*89c4ff92SAndroid Build Coastguard Worker -0.12584099173f, -0.12319286912f, 0.2407919466f};
668*89c4ff92SAndroid Build Coastguard Worker
669*89c4ff92SAndroid Build Coastguard Worker //tensorInfo20
670*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
671*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToInputWeights = { -0.128009796112f, 0.1995525098f, -0.07745539397f, 0.1558421701f,
672*89c4ff92SAndroid Build Coastguard Worker -0.265254765766f, -0.38837709614f, -0.05636804124f, 0.4259087456f,
673*89c4ff92SAndroid Build Coastguard Worker 0.17628988623f, 0.3877420127f, 0.53300309181f, -0.0959980934f,
674*89c4ff92SAndroid Build Coastguard Worker 0.00302857416f, 0.3266998827f, -0.142509296562f, -0.04433270756f,
675*89c4ff92SAndroid Build Coastguard Worker 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f };
676*89c4ff92SAndroid Build Coastguard Worker
677*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToForgetWeights = { -0.09499983487f, -0.08814888417f, -0.04834804721f, 0.1516668247f,
678*89c4ff92SAndroid Build Coastguard Worker -0.3967529535f, -0.06463699788f, 0.4952811002f, 0.003274492938f,
679*89c4ff92SAndroid Build Coastguard Worker -0.0968840941f, 0.17928104102f, 0.0031281141592f, -0.3387276584f,
680*89c4ff92SAndroid Build Coastguard Worker -0.3587934076f, 0.06705895066f, 0.22463923692f, 0.1961955726f,
681*89c4ff92SAndroid Build Coastguard Worker 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f };
682*89c4ff92SAndroid Build Coastguard Worker
683*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToCellWeights = { -0.21938985582f, -0.3023648226f, -0.1170005202f, -0.3509177422f,
684*89c4ff92SAndroid Build Coastguard Worker -0.4286288613f, 0.2726137042f, 0.09216640889f, -0.06551410215f,
685*89c4ff92SAndroid Build Coastguard Worker 0.20453298098f, 0.2393476665f, 0.11846517771f, 0.2630801796f,
686*89c4ff92SAndroid Build Coastguard Worker 0.3954237699f, -0.19407111404f, 0.30412107706f, -0.27342408554f,
687*89c4ff92SAndroid Build Coastguard Worker 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f };
688*89c4ff92SAndroid Build Coastguard Worker
689*89c4ff92SAndroid Build Coastguard Worker std::vector<float> recurrentToOutputWeights = { -0.32921677827f, 0.32624614238f, -0.1388191282f, -0.17879831790f,
690*89c4ff92SAndroid Build Coastguard Worker -0.15185534954f, -0.16918526583f, -0.10087361183f, -0.5436913968f,
691*89c4ff92SAndroid Build Coastguard Worker 0.016758225858f, 0.30454617738f, -0.41493862867f, -0.005565764375f,
692*89c4ff92SAndroid Build Coastguard Worker -0.12584099173f, -0.12319286912f, 0.2407919466f, -0.08879069983f,
693*89c4ff92SAndroid Build Coastguard Worker 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f };
694*89c4ff92SAndroid Build Coastguard Worker // tensorInfo5
695*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = true;
696*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToInputWeights = { 0.05f, 0.1f, 0.25f, 0.15f, -0.02f };
697*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = true;
698*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToForgetWeights = { -0.02f, -0.15f, -0.25f, -0.03f, 0.15f };
699*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = true;
700*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellToOutputWeights = { 0.1f, -0.1f, -0.5f, 0.05f, 0.01f };
701*89c4ff92SAndroid Build Coastguard Worker
702*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
703*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = { 0.03f, 0.15f, 0.22f, 0.38f, 0.05f };
704*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 0.1f, -0.3f, -0.2f, 0.1f, 0.4f };
705*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { -0.05f, 0.72f, 0.25f, 0.08f, 0.1f };
706*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0.05f, -0.01f, 0.2f, 0.1f, -0.2f };
707*89c4ff92SAndroid Build Coastguard Worker
708*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = true;
709*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionWeights = { -0.1f, 0.2f, 0.01f, -0.2f,
710*89c4ff92SAndroid Build Coastguard Worker 0.1f, 0.5f, 0.3f, 0.08f,
711*89c4ff92SAndroid Build Coastguard Worker 0.07f, 0.2f, -0.4f, 0.2f,
712*89c4ff92SAndroid Build Coastguard Worker 0.5f, -0.4f, 0.3f, -0.2f,
713*89c4ff92SAndroid Build Coastguard Worker 0.3f, 0.08f, -0.07f, 0.2f}; //{outputSize, numUnits}
714*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = true;
715*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias(outputSize, 0.f);;
716*89c4ff92SAndroid Build Coastguard Worker
717*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = true;
718*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights = { 0.1f, 0.2f, 0.3f, 0.5f, 0.8f };
719*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = true;
720*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights = { 0.1f, 0.2f, 0.3f, 0.5f, 0.2f };
721*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = true;
722*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights = { 0.7f, 0.2f, 0.3f, 0.8f, 0.5f };
723*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = true;
724*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights = { 0.6f, 0.2f, 0.2f, 0.5f, 0.1f };
725*89c4ff92SAndroid Build Coastguard Worker
726*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1., 2., 3., 4., 5., 4.,
727*89c4ff92SAndroid Build Coastguard Worker 3., 2., 1., 2., 3., 4.,
728*89c4ff92SAndroid Build Coastguard Worker 5., 4., 3., 2., 1., 2. };
729*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 0.0642256f, 0.0343966f, 0.184122f, 0.114717f,
730*89c4ff92SAndroid Build Coastguard Worker 0.11458f, 0.0407109f, 0.300327f, 0.174301f,
731*89c4ff92SAndroid Build Coastguard Worker 0.0864761f, 0.0362912f, 0.178635f, 0.115689f,
732*89c4ff92SAndroid Build Coastguard Worker 0.108008f, 0.0386623f, 0.273471f, 0.167115f,
733*89c4ff92SAndroid Build Coastguard Worker 0.0859545f, 0.0331481f, 0.186051f, 0.11888f,
734*89c4ff92SAndroid Build Coastguard Worker 0.106649f, 0.0276847f, 0.229863f, 0.166958f };
735*89c4ff92SAndroid Build Coastguard Worker
736*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
737*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
738*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
739*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
740*89c4ff92SAndroid Build Coastguard Worker
741*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<float>(backends,
742*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
743*89c4ff92SAndroid Build Coastguard Worker batchSize,
744*89c4ff92SAndroid Build Coastguard Worker timeSize,
745*89c4ff92SAndroid Build Coastguard Worker inputSize,
746*89c4ff92SAndroid Build Coastguard Worker outputSize,
747*89c4ff92SAndroid Build Coastguard Worker numUnits,
748*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
749*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
750*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
751*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
752*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
753*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
754*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
755*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
756*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
757*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
758*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
759*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
760*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
761*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
762*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
763*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
764*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
765*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
766*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
767*89c4ff92SAndroid Build Coastguard Worker cellBias,
768*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
769*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
770*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
771*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
772*89c4ff92SAndroid Build Coastguard Worker projectionBias,
773*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
774*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
775*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
776*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
777*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
778*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
779*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
780*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
781*89c4ff92SAndroid Build Coastguard Worker inputValues,
782*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
783*89c4ff92SAndroid Build Coastguard Worker activationFunction,
784*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
785*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
786*89c4ff92SAndroid Build Coastguard Worker isTimeMajor);
787*89c4ff92SAndroid Build Coastguard Worker }
788*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmInt8Test(std::vector<armnn::BackendId> & backends)789*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmInt8Test(std::vector<armnn::BackendId>& backends)
790*89c4ff92SAndroid Build Coastguard Worker {
791*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
792*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
793*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
794*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
795*89c4ff92SAndroid Build Coastguard Worker // cellSize and outputSize have the same size when there is no projection.
796*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = outputSize;
797*89c4ff92SAndroid Build Coastguard Worker
798*89c4ff92SAndroid Build Coastguard Worker //tensorInfo12
799*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
800*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToInputWeights = { -4, -1, -1, -2, 3, -2, 2, 4, 1, -4, -2, 3 };
801*89c4ff92SAndroid Build Coastguard Worker
802*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToForgetWeights = { 2, 1, 4, -4, 3, -1, -3, -2, -3, 1, -4, -1 };
803*89c4ff92SAndroid Build Coastguard Worker
804*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToCellWeights = { -2, 1, -2, 4, -3, -2, -4, 3, -2, -2, -6, 3 };
805*89c4ff92SAndroid Build Coastguard Worker
806*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToOutputWeights = { 2, 5, -4, 5, 2, -3, 5, 7, 3, -5, 1, -4 };
807*89c4ff92SAndroid Build Coastguard Worker
808*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16
809*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
810*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToInputWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -1, -1 };
811*89c4ff92SAndroid Build Coastguard Worker
812*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToForgetWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -2, -1 };
813*89c4ff92SAndroid Build Coastguard Worker
814*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToCellWeights = { -2, -3, -1, -3, -4, 2, 1, -1, 2, 2, 1, 2, 3, -2, 3, -3 };
815*89c4ff92SAndroid Build Coastguard Worker
816*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToOutputWeights = { -3, 3, -1, -2, -2, -2, -1, -5, 1, 3, -4, -1, -1, -1, 2, -1 };
817*89c4ff92SAndroid Build Coastguard Worker
818*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
819*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = false;
820*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToInputWeights;
821*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = false;
822*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToForgetWeights;
823*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = false;
824*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToOutputWeights;
825*89c4ff92SAndroid Build Coastguard Worker
826*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
827*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = { 0., 0., 0., 0. };
828*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 1., 1., 1., 1. };
829*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { 0., 0., 0., 0. };
830*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0., 0., 0., 0. };
831*89c4ff92SAndroid Build Coastguard Worker
832*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = false;
833*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> projectionWeights;
834*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = false;
835*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias;
836*89c4ff92SAndroid Build Coastguard Worker
837*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
838*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
839*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
840*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
841*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
842*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
843*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
844*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
845*89c4ff92SAndroid Build Coastguard Worker
846*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.4f,
847*89c4ff92SAndroid Build Coastguard Worker 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.4f,
848*89c4ff92SAndroid Build Coastguard Worker 0.5f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f };
849*89c4ff92SAndroid Build Coastguard Worker
850*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { -0.0142517f, -0.0198845f, -0.0120569f, -0.0116868f,
851*89c4ff92SAndroid Build Coastguard Worker -0.0350714f, -0.0343202f, -0.047504f, -0.0569789f,
852*89c4ff92SAndroid Build Coastguard Worker -0.0146346f, 0.0106663f, -0.0247238f, -0.0319502f,
853*89c4ff92SAndroid Build Coastguard Worker -0.0294759f, -0.0129935f, -0.0444175f, -0.0444354f,
854*89c4ff92SAndroid Build Coastguard Worker -0.0280855f, 0.00545101f, -0.051422f, -0.0463838f,
855*89c4ff92SAndroid Build Coastguard Worker -0.0310702f, 0.00915739f, -0.0625207f, -0.0482648f };
856*89c4ff92SAndroid Build Coastguard Worker
857*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
858*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
859*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
860*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
861*89c4ff92SAndroid Build Coastguard Worker
862*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<int8_t>(backends,
863*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8,
864*89c4ff92SAndroid Build Coastguard Worker batchSize,
865*89c4ff92SAndroid Build Coastguard Worker timeSize,
866*89c4ff92SAndroid Build Coastguard Worker inputSize,
867*89c4ff92SAndroid Build Coastguard Worker outputSize,
868*89c4ff92SAndroid Build Coastguard Worker numUnits,
869*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
870*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
871*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
872*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
873*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
874*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
875*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
876*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
877*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
878*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
879*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
880*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
881*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
882*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
883*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
884*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
885*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
886*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
887*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
888*89c4ff92SAndroid Build Coastguard Worker cellBias,
889*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
890*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
891*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
892*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
893*89c4ff92SAndroid Build Coastguard Worker projectionBias,
894*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
895*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
896*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
897*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
898*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
899*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
900*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
901*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
902*89c4ff92SAndroid Build Coastguard Worker inputValues,
903*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
904*89c4ff92SAndroid Build Coastguard Worker activationFunction,
905*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
906*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
907*89c4ff92SAndroid Build Coastguard Worker isTimeMajor,
908*89c4ff92SAndroid Build Coastguard Worker 0.1f);
909*89c4ff92SAndroid Build Coastguard Worker }
910*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmInt8TimeMajorTest(std::vector<armnn::BackendId> & backends)911*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmInt8TimeMajorTest(std::vector<armnn::BackendId>& backends)
912*89c4ff92SAndroid Build Coastguard Worker {
913*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
914*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
915*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
916*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
917*89c4ff92SAndroid Build Coastguard Worker // cellSize and outputSize have the same size when there is no projection.
918*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = outputSize;
919*89c4ff92SAndroid Build Coastguard Worker
920*89c4ff92SAndroid Build Coastguard Worker //tensorInfo12
921*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
922*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToInputWeights = { -4, -1, -1, -2, 3, -2, 2, 4, 1, -4, -2, 3 };
923*89c4ff92SAndroid Build Coastguard Worker
924*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToForgetWeights = { 2, 1, 4, -4, 3, -1, -3, -2, -3, 1, -4, -1 };
925*89c4ff92SAndroid Build Coastguard Worker
926*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToCellWeights = { -2, 1, -2, 4, -3, -2, -4, 3, -2, -2, -6, 3 };
927*89c4ff92SAndroid Build Coastguard Worker
928*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToOutputWeights = { 2, 5, -4, 5, 2, -3, 5, 7, 3, -5, 1, -4 };
929*89c4ff92SAndroid Build Coastguard Worker
930*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16
931*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
932*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToInputWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -1, -1 };
933*89c4ff92SAndroid Build Coastguard Worker
934*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToForgetWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -2, -1 };
935*89c4ff92SAndroid Build Coastguard Worker
936*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToCellWeights = { -2, -3, -1, -3, -4, 2, 1, -1, 2, 2, 1, 2, 3, -2, 3, -3 };
937*89c4ff92SAndroid Build Coastguard Worker
938*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToOutputWeights = { -3, 3, -1, -2, -2, -2, -1, -5, 1, 3, -4, -1, -1, -1, 2, -1 };
939*89c4ff92SAndroid Build Coastguard Worker
940*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
941*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = false;
942*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToInputWeights;
943*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = false;
944*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToForgetWeights;
945*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = false;
946*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToOutputWeights;
947*89c4ff92SAndroid Build Coastguard Worker
948*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
949*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = { 0., 0., 0., 0. };
950*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 1., 1., 1., 1. };
951*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { 0., 0., 0., 0. };
952*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0., 0., 0., 0. };
953*89c4ff92SAndroid Build Coastguard Worker
954*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = false;
955*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> projectionWeights;
956*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = false;
957*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias;
958*89c4ff92SAndroid Build Coastguard Worker
959*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
960*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
961*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
962*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
963*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
964*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
965*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
966*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
967*89c4ff92SAndroid Build Coastguard Worker
968*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.4f,
969*89c4ff92SAndroid Build Coastguard Worker 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.4f,
970*89c4ff92SAndroid Build Coastguard Worker 0.5f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f };
971*89c4ff92SAndroid Build Coastguard Worker
972*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { -0.0142517f, -0.0198845f, -0.0120122f, -0.0116868f,
973*89c4ff92SAndroid Build Coastguard Worker -0.0261295f, -0.0188487f, -0.0345463f, -0.049733f,
974*89c4ff92SAndroid Build Coastguard Worker -0.0146346f, 0.0106663f, -0.0247238f, -0.0319502f,
975*89c4ff92SAndroid Build Coastguard Worker -0.0291863f, -0.0369402f, -0.0354071f, -0.0296529f,
976*89c4ff92SAndroid Build Coastguard Worker -0.0419539f, -0.00617731f, -0.0814796f, -0.0804005f,
977*89c4ff92SAndroid Build Coastguard Worker -0.0244737f, 0.0119905f, -0.0457527f, -0.0331862f };
978*89c4ff92SAndroid Build Coastguard Worker
979*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
980*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
981*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
982*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = true;
983*89c4ff92SAndroid Build Coastguard Worker
984*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<int8_t>(backends,
985*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8,
986*89c4ff92SAndroid Build Coastguard Worker batchSize,
987*89c4ff92SAndroid Build Coastguard Worker timeSize,
988*89c4ff92SAndroid Build Coastguard Worker inputSize,
989*89c4ff92SAndroid Build Coastguard Worker outputSize,
990*89c4ff92SAndroid Build Coastguard Worker numUnits,
991*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
992*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
993*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
994*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
995*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
996*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
997*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
998*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
999*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
1000*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
1001*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
1002*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
1003*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
1004*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
1005*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
1006*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
1007*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
1008*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
1009*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
1010*89c4ff92SAndroid Build Coastguard Worker cellBias,
1011*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
1012*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
1013*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
1014*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
1015*89c4ff92SAndroid Build Coastguard Worker projectionBias,
1016*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
1017*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
1018*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
1019*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
1020*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
1021*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
1022*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
1023*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
1024*89c4ff92SAndroid Build Coastguard Worker inputValues,
1025*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
1026*89c4ff92SAndroid Build Coastguard Worker activationFunction,
1027*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
1028*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
1029*89c4ff92SAndroid Build Coastguard Worker isTimeMajor,
1030*89c4ff92SAndroid Build Coastguard Worker 0.1);
1031*89c4ff92SAndroid Build Coastguard Worker }
1032*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionTest(std::vector<armnn::BackendId> & backends)1033*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionTest(std::vector<armnn::BackendId>& backends)
1034*89c4ff92SAndroid Build Coastguard Worker {
1035*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
1036*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
1037*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
1038*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
1039*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = 4;
1040*89c4ff92SAndroid Build Coastguard Worker
1041*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
1042*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToInputWeights = { -4, -1, -1, -2, 3, -2, 2, 4, 1, -4, -2, 3 };
1043*89c4ff92SAndroid Build Coastguard Worker
1044*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToForgetWeights = { 2, 1, 4, -4, 3, -1, -3, -2, -3, 1, -4, -1 };
1045*89c4ff92SAndroid Build Coastguard Worker
1046*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToCellWeights = { -2, 1, -2, 4, -3, -2, -4, 3, -2, -2, -6, 3 };
1047*89c4ff92SAndroid Build Coastguard Worker
1048*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToOutputWeights = { 2, 5, -4, 5, 2, -3, 5, 7, 3, -5, 1, -4 };
1049*89c4ff92SAndroid Build Coastguard Worker
1050*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16
1051*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
1052*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToInputWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -1, -1 };
1053*89c4ff92SAndroid Build Coastguard Worker
1054*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToForgetWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -2, -1 };
1055*89c4ff92SAndroid Build Coastguard Worker
1056*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToCellWeights = { -2, -3, -1, -3, -4, 2, 1, -1, 2, 2, 1, 2, 3, -2, 3, -3 };
1057*89c4ff92SAndroid Build Coastguard Worker
1058*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToOutputWeights = { -3, 3, -1, -2, -2, -2, -1, -5, 1, 3, -4, -1, -1, -1, 2, -1 };
1059*89c4ff92SAndroid Build Coastguard Worker
1060*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
1061*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = true;
1062*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToInputWeights = { 5, 10, 25, 15 };
1063*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = true;
1064*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToForgetWeights = { -5, 15, 25, 3 };
1065*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = true;
1066*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToOutputWeights = { 10, -10, -5, 50 };
1067*89c4ff92SAndroid Build Coastguard Worker
1068*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
1069*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = { 0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f};
1070*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 0.035185695f, -0.042891346f, -0.3032477f, 0.23027696f};
1071*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { -0.124379363f, 0.55531194f, 0.23377132f, 0.033463873f };
1072*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0.046159424f, -0.12809046f, 0.03563469f, 0.12648113f };
1073*89c4ff92SAndroid Build Coastguard Worker
1074*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = true;
1075*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> projectionWeights = { -25, 51, 3, -5, 25, 127, 77, 20, 18, 51, -10, 51, -25, 88, 77, -13 };
1076*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = true;
1077*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias(outputSize, 0.f);
1078*89c4ff92SAndroid Build Coastguard Worker
1079*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
1080*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
1081*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
1082*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
1083*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
1084*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
1085*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
1086*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
1087*89c4ff92SAndroid Build Coastguard Worker
1088*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.4f,
1089*89c4ff92SAndroid Build Coastguard Worker 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.4f,
1090*89c4ff92SAndroid Build Coastguard Worker 0.5f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f };
1091*89c4ff92SAndroid Build Coastguard Worker
1092*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 0.612103f, 1.56788f, 0.31966f, 1.42956f,
1093*89c4ff92SAndroid Build Coastguard Worker 0.909718f, 3.07916f, -0.560586f, 3.8907f,
1094*89c4ff92SAndroid Build Coastguard Worker 0.753671f, 1.77485f, 0.365122f, 1.60077f,
1095*89c4ff92SAndroid Build Coastguard Worker 0.812644f, 2.79092f, -0.605396f, 3.61742f,
1096*89c4ff92SAndroid Build Coastguard Worker 0.791857f, 1.64353f, 0.316588f, 1.55192f,
1097*89c4ff92SAndroid Build Coastguard Worker 0.807265f, 2.47012f, -0.539598f, 3.25654f };
1098*89c4ff92SAndroid Build Coastguard Worker
1099*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
1100*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
1101*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
1102*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
1103*89c4ff92SAndroid Build Coastguard Worker
1104*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<int8_t>(backends,
1105*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8,
1106*89c4ff92SAndroid Build Coastguard Worker batchSize,
1107*89c4ff92SAndroid Build Coastguard Worker timeSize,
1108*89c4ff92SAndroid Build Coastguard Worker inputSize,
1109*89c4ff92SAndroid Build Coastguard Worker outputSize,
1110*89c4ff92SAndroid Build Coastguard Worker numUnits,
1111*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
1112*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
1113*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
1114*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
1115*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
1116*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
1117*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
1118*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
1119*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
1120*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
1121*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
1122*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
1123*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
1124*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
1125*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
1126*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
1127*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
1128*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
1129*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
1130*89c4ff92SAndroid Build Coastguard Worker cellBias,
1131*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
1132*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
1133*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
1134*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
1135*89c4ff92SAndroid Build Coastguard Worker projectionBias,
1136*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
1137*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
1138*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
1139*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
1140*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
1141*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
1142*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
1143*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
1144*89c4ff92SAndroid Build Coastguard Worker inputValues,
1145*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
1146*89c4ff92SAndroid Build Coastguard Worker activationFunction,
1147*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
1148*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
1149*89c4ff92SAndroid Build Coastguard Worker isTimeMajor,
1150*89c4ff92SAndroid Build Coastguard Worker 0.1f);
1151*89c4ff92SAndroid Build Coastguard Worker }
1152*89c4ff92SAndroid Build Coastguard Worker
UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjectionTest(std::vector<armnn::BackendId> & backends)1153*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjectionTest(std::vector<armnn::BackendId>& backends)
1154*89c4ff92SAndroid Build Coastguard Worker {
1155*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
1156*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
1157*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
1158*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
1159*89c4ff92SAndroid Build Coastguard Worker // cellSize and outputSize have the same size when there is no projection.
1160*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = outputSize;
1161*89c4ff92SAndroid Build Coastguard Worker
1162*89c4ff92SAndroid Build Coastguard Worker //tensorInfo12,
1163*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = false;
1164*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToInputWeights;
1165*89c4ff92SAndroid Build Coastguard Worker
1166*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToForgetWeights = { 2, 1, 4, -4, 3, -1, -3, -2, -3, 1, -4, -1 };
1167*89c4ff92SAndroid Build Coastguard Worker
1168*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToCellWeights = { -2, 1, -2, 4, -3, -2, -4, 3, -2, -2, -6, 3 };
1169*89c4ff92SAndroid Build Coastguard Worker
1170*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToOutputWeights = { 2, 5, -4, 5, 2, -3, 5, 7, 3, -5, 1, -4 };
1171*89c4ff92SAndroid Build Coastguard Worker
1172*89c4ff92SAndroid Build Coastguard Worker //tensorInfo16,
1173*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = false;
1174*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToInputWeights;
1175*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToForgetWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3, 5, -1, 1, 3, -2, -1 };
1176*89c4ff92SAndroid Build Coastguard Worker
1177*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToCellWeights = { -2, -3, -1, -3, -4, 2, 1, -1, 2, 2, 1, 2, 3, -2, 3, -3 };
1178*89c4ff92SAndroid Build Coastguard Worker
1179*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToOutputWeights = { -3, 3, -1, -2, -2, -2, -1, -5, 1, 3, -4, -1, -1, -1, 2, -1 };
1180*89c4ff92SAndroid Build Coastguard Worker
1181*89c4ff92SAndroid Build Coastguard Worker // tensorInfo4
1182*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = false;
1183*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToInputWeights;
1184*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = true;
1185*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToForgetWeights = { 47, -52, -24, 31 };
1186*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = true;
1187*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToOutputWeights = { -17, 82, 85, -77 };
1188*89c4ff92SAndroid Build Coastguard Worker
1189*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = false;
1190*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias;
1191*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 1., 1., 1., 1. };
1192*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { 0., 0., 0., 0. };
1193*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0., 0., 0., 0. };
1194*89c4ff92SAndroid Build Coastguard Worker
1195*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = false;
1196*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> projectionWeights;
1197*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = false;
1198*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias;
1199*89c4ff92SAndroid Build Coastguard Worker
1200*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = false;
1201*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights;
1202*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = false;
1203*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights;
1204*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = false;
1205*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights;
1206*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = false;
1207*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights;
1208*89c4ff92SAndroid Build Coastguard Worker
1209*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.4f,
1210*89c4ff92SAndroid Build Coastguard Worker 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.4f,
1211*89c4ff92SAndroid Build Coastguard Worker 0.5f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f };
1212*89c4ff92SAndroid Build Coastguard Worker
1213*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { -0.0072104f, -0.00991171f, -0.00650478f, -0.00713055f,
1214*89c4ff92SAndroid Build Coastguard Worker -0.0191782f, -0.0161269f, -0.0233683f, -0.054299f,
1215*89c4ff92SAndroid Build Coastguard Worker -0.00783725f, 0.00635271f, -0.0126718f, -0.022613f,
1216*89c4ff92SAndroid Build Coastguard Worker -0.0161351f, -0.00775868f, -0.021054f, -0.0339778f,
1217*89c4ff92SAndroid Build Coastguard Worker -0.0146392f, 0.00330261f, -0.0258733f, -0.0407797f,
1218*89c4ff92SAndroid Build Coastguard Worker -0.0174297f, 0.0050105f, -0.0266275f, -0.0362564f };
1219*89c4ff92SAndroid Build Coastguard Worker
1220*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
1221*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
1222*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
1223*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
1224*89c4ff92SAndroid Build Coastguard Worker
1225*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<int8_t>(backends,
1226*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8,
1227*89c4ff92SAndroid Build Coastguard Worker batchSize,
1228*89c4ff92SAndroid Build Coastguard Worker timeSize,
1229*89c4ff92SAndroid Build Coastguard Worker inputSize,
1230*89c4ff92SAndroid Build Coastguard Worker outputSize,
1231*89c4ff92SAndroid Build Coastguard Worker numUnits,
1232*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
1233*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
1234*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
1235*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
1236*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
1237*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
1238*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
1239*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
1240*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
1241*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
1242*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
1243*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
1244*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
1245*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
1246*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
1247*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
1248*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
1249*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
1250*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
1251*89c4ff92SAndroid Build Coastguard Worker cellBias,
1252*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
1253*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
1254*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
1255*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
1256*89c4ff92SAndroid Build Coastguard Worker projectionBias,
1257*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
1258*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
1259*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
1260*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
1261*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
1262*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
1263*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
1264*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
1265*89c4ff92SAndroid Build Coastguard Worker inputValues,
1266*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
1267*89c4ff92SAndroid Build Coastguard Worker activationFunction,
1268*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
1269*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
1270*89c4ff92SAndroid Build Coastguard Worker isTimeMajor,
1271*89c4ff92SAndroid Build Coastguard Worker 0.1);
1272*89c4ff92SAndroid Build Coastguard Worker }
1273*89c4ff92SAndroid Build Coastguard Worker
// Int8 UnidirectionalSequenceLstm test with every optional tensor supplied: input-gate weights
// (no CIFG), peephole weights for all three gates, projection weights and bias, and layer-norm
// weights for all four gates -- every has* flag below is true.
// Weight tensors are int8; biases, layer-norm weights, inputs and expected outputs are float.
UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionWithLayerNormTest(std::vector<armnn::BackendId> & backends)1274*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionWithLayerNormTest(
1275*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId>& backends)
1276*89c4ff92SAndroid Build Coastguard Worker {
1277*89c4ff92SAndroid Build Coastguard Worker int32_t batchSize = 3;
1278*89c4ff92SAndroid Build Coastguard Worker int32_t timeSize = 2;
1279*89c4ff92SAndroid Build Coastguard Worker int32_t inputSize = 3;
1280*89c4ff92SAndroid Build Coastguard Worker int32_t outputSize = 4;
// numUnits (cell size) differs from outputSize here because projection is enabled; the
// projection weights below hold outputSize * numUnits values.
1281*89c4ff92SAndroid Build Coastguard Worker int32_t numUnits = 5;
1282*89c4ff92SAndroid Build Coastguard Worker
// Input-to-gate weights: 15 values each (numUnits * inputSize). The element layout is defined by
// UnidirectionalSequenceLstmTestImpl -- TODO confirm row/column order there.
1283*89c4ff92SAndroid Build Coastguard Worker bool hasInputToInputWeights = true;
1284*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToInputWeights = { -4, -1, -1, -2, 3, -2, 2, 4, 1, -4, -2, 3, 2, 2, -4 };
1285*89c4ff92SAndroid Build Coastguard Worker
1286*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToForgetWeights = { 2, 1, 4, -4, 3, -1, -3, -2, -3, 1, -4, -1, -3, -2, -4 };
1287*89c4ff92SAndroid Build Coastguard Worker
1288*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToCellWeights = { -2, 1, -2, 4, -3, -2, -4, 3, -2, -2, -6, 3, 2, 5, -4 };
1289*89c4ff92SAndroid Build Coastguard Worker
1290*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputToOutputWeights = { 2, 5, -4, 5, 2, -3, 5, 7, 3, -5, 1, -4, -4, -1, -1 };
1291*89c4ff92SAndroid Build Coastguard Worker
// Recurrent-to-gate weights: 20 values each (numUnits * outputSize).
1292*89c4ff92SAndroid Build Coastguard Worker bool hasRecurrentToInputWeights = true;
1293*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToInputWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3,
1294*89c4ff92SAndroid Build Coastguard Worker 5, -1, 1, 3, -1, -1, -1, 4, 2, 3 };
1295*89c4ff92SAndroid Build Coastguard Worker
1296*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToForgetWeights = { -1, 1, -1, 1, -3, -4, -1, 4, 2, 3,
1297*89c4ff92SAndroid Build Coastguard Worker 5, -1, 1, 3, -2, -1, -1, 2, 2, 1 };
1298*89c4ff92SAndroid Build Coastguard Worker
1299*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToCellWeights = { -2, -3, -1, -3, -4, 2, 1, -1, 2, 2,
1300*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3, -2, 3, -3, -1, -5, 1, 3 };
1301*89c4ff92SAndroid Build Coastguard Worker
1302*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> recurrentToOutputWeights = { -3, 3, -1, -2, -2, -2, -1, -5, 1, 3,
1303*89c4ff92SAndroid Build Coastguard Worker -4, -1, -1, -1, 2, -1, 5, 1, -3, -4 };
1304*89c4ff92SAndroid Build Coastguard Worker
1305*89c4ff92SAndroid Build Coastguard Worker // tensorInfo5: peephole (cell-to-gate) weights, one value per cell unit (numUnits == 5).
1306*89c4ff92SAndroid Build Coastguard Worker bool hasCellToInputWeights = true;
1307*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToInputWeights = { 5, 3, 8, -5, 2 };
1308*89c4ff92SAndroid Build Coastguard Worker bool hasCellToForgetWeights = true;
1309*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToForgetWeights = { -2, -7, 5, -3, 4 };
1310*89c4ff92SAndroid Build Coastguard Worker bool hasCellToOutputWeights = true;
1311*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> cellToOutputWeights = { 9, -10 , -5, 5, 1 };
1312*89c4ff92SAndroid Build Coastguard Worker
// Gate biases stay float even in the int8 variant: numUnits (5) values each.
1313*89c4ff92SAndroid Build Coastguard Worker bool hasInputGateBias = true;
1314*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputGateBias = { 0.03f, 0.15f, 0.22f, 0.38f, 0.05f };
1315*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetGateBias = { 0.1f, -0.3f, -0.2f, 0.1f, 0.4f };
1316*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellBias = { -0.05f, 0.72f, 0.25f, 0.08f, 0.1f };
1317*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputGateBias = { 0.05f, -0.01f, 0.2f, 0.1f, -0.2f };
1318*89c4ff92SAndroid Build Coastguard Worker
// Projection weights: 20 values (outputSize * numUnits).
1319*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionWeights = true;
1320*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> projectionWeights = { -1, 2, 1, -2, 1, 5, 3, 8, 7, 2,
1321*89c4ff92SAndroid Build Coastguard Worker -4, 2, 5, -4, 3, -2, 3, 8, -7, 2 };
1322*89c4ff92SAndroid Build Coastguard Worker bool hasProjectionBias = true;
// (count, value) constructor, not a braced list: outputSize zeros.
1323*89c4ff92SAndroid Build Coastguard Worker std::vector<float> projectionBias(outputSize, 0.f);
1324*89c4ff92SAndroid Build Coastguard Worker
// Layer-normalisation weights for each gate: numUnits (5) float values each.
1325*89c4ff92SAndroid Build Coastguard Worker bool hasInputLayerNormWeights = true;
1326*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputLayerNormWeights = { 0.1f, 0.2f, -0.3f, -0.1f, 0.5f };
1327*89c4ff92SAndroid Build Coastguard Worker bool hasForgetLayerNormWeights = true;
1328*89c4ff92SAndroid Build Coastguard Worker std::vector<float> forgetLayerNormWeights = { -0.1f, 0.2f, 0.3f, 0.5f, 0.2f };
1329*89c4ff92SAndroid Build Coastguard Worker bool hasCellLayerNormWeights = true;
1330*89c4ff92SAndroid Build Coastguard Worker std::vector<float> cellLayerNormWeights = { 0.5f, 0.2f, 0.3f, 0.4f, -0.5f };
1331*89c4ff92SAndroid Build Coastguard Worker bool hasOutputLayerNormWeights = true;
1332*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputLayerNormWeights = { 0.6f, -0.2f, -0.2f, 0.5f, 0.1f };
1333*89c4ff92SAndroid Build Coastguard Worker
// 18 input values = batchSize * timeSize * inputSize. isTimeMajor is false below, so presumably
// batch-major ordering -- TODO confirm against the helper.
1334*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1., 8., 3., 4., 5., 4.,
1335*89c4ff92SAndroid Build Coastguard Worker 3., 2., 1., 2., 3., 4.,
1336*89c4ff92SAndroid Build Coastguard Worker 5., 4., 3., 2., 1., 2. };
1337*89c4ff92SAndroid Build Coastguard Worker
// 24 expected output values = batchSize * timeSize * outputSize.
1338*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 0.0471276f, 0.0168155f, 0.0789885f, 0.16550f,
1339*89c4ff92SAndroid Build Coastguard Worker 0.0643133f, -0.0400722f, 0.100593f, 0.197722f,
1340*89c4ff92SAndroid Build Coastguard Worker 0.0465562f, -0.0600682f, 0.0622087f, 0.115053f,
1341*89c4ff92SAndroid Build Coastguard Worker 0.056287f, -0.0566218f, 0.0856832f, 0.148484f,
1342*89c4ff92SAndroid Build Coastguard Worker 0.0457859f, -0.0588112f, 0.0623636f, 0.114333f,
1343*89c4ff92SAndroid Build Coastguard Worker 0.0509271f, -0.0754262f, 0.058600f, 0.0801288f };
1344*89c4ff92SAndroid Build Coastguard Worker
1345*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
1346*89c4ff92SAndroid Build Coastguard Worker float clippingThresCell = 10.f;
// NOTE(review): a 0 projection clip threshold presumably disables projection clipping -- confirm.
1347*89c4ff92SAndroid Build Coastguard Worker float clippingThresProj = 0.f;
1348*89c4ff92SAndroid Build Coastguard Worker bool isTimeMajor = false;
1349*89c4ff92SAndroid Build Coastguard Worker
// Arguments are positional and many neighbours share a type, so their order must match the
// helper's parameter list exactly.
1350*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTestImpl<int8_t>(backends,
1351*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8,
1352*89c4ff92SAndroid Build Coastguard Worker batchSize,
1353*89c4ff92SAndroid Build Coastguard Worker timeSize,
1354*89c4ff92SAndroid Build Coastguard Worker inputSize,
1355*89c4ff92SAndroid Build Coastguard Worker outputSize,
1356*89c4ff92SAndroid Build Coastguard Worker numUnits,
1357*89c4ff92SAndroid Build Coastguard Worker hasInputToInputWeights,
1358*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights,
1359*89c4ff92SAndroid Build Coastguard Worker inputToForgetWeights,
1360*89c4ff92SAndroid Build Coastguard Worker inputToCellWeights,
1361*89c4ff92SAndroid Build Coastguard Worker inputToOutputWeights,
1362*89c4ff92SAndroid Build Coastguard Worker hasRecurrentToInputWeights,
1363*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights,
1364*89c4ff92SAndroid Build Coastguard Worker recurrentToForgetWeights,
1365*89c4ff92SAndroid Build Coastguard Worker recurrentToCellWeights,
1366*89c4ff92SAndroid Build Coastguard Worker recurrentToOutputWeights,
1367*89c4ff92SAndroid Build Coastguard Worker hasCellToInputWeights,
1368*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights,
1369*89c4ff92SAndroid Build Coastguard Worker hasCellToForgetWeights,
1370*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights,
1371*89c4ff92SAndroid Build Coastguard Worker hasCellToOutputWeights,
1372*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights,
1373*89c4ff92SAndroid Build Coastguard Worker hasInputGateBias,
1374*89c4ff92SAndroid Build Coastguard Worker inputGateBias,
1375*89c4ff92SAndroid Build Coastguard Worker forgetGateBias,
1376*89c4ff92SAndroid Build Coastguard Worker cellBias,
1377*89c4ff92SAndroid Build Coastguard Worker outputGateBias,
1378*89c4ff92SAndroid Build Coastguard Worker hasProjectionWeights,
1379*89c4ff92SAndroid Build Coastguard Worker projectionWeights,
1380*89c4ff92SAndroid Build Coastguard Worker hasProjectionBias,
1381*89c4ff92SAndroid Build Coastguard Worker projectionBias,
1382*89c4ff92SAndroid Build Coastguard Worker hasInputLayerNormWeights,
1383*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights,
1384*89c4ff92SAndroid Build Coastguard Worker hasForgetLayerNormWeights,
1385*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights,
1386*89c4ff92SAndroid Build Coastguard Worker hasCellLayerNormWeights,
1387*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights,
1388*89c4ff92SAndroid Build Coastguard Worker hasOutputLayerNormWeights,
1389*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights,
1390*89c4ff92SAndroid Build Coastguard Worker inputValues,
1391*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
1392*89c4ff92SAndroid Build Coastguard Worker activationFunction,
1393*89c4ff92SAndroid Build Coastguard Worker clippingThresCell,
1394*89c4ff92SAndroid Build Coastguard Worker clippingThresProj,
1395*89c4ff92SAndroid Build Coastguard Worker isTimeMajor,
// NOTE(review): trailing 0.1 fills the helper's last parameter (presumably a quantization
// scale) -- confirm against UnidirectionalSequenceLstmTestImpl.
1396*89c4ff92SAndroid Build Coastguard Worker 0.1);
1397*89c4ff92SAndroid Build Coastguard Worker }
1398*89c4ff92SAndroid Build Coastguard Worker
1399*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("UnidirectionalSequenceLstmTest_CpuRefTests")
1400*89c4ff92SAndroid Build Coastguard Worker {
1401*89c4ff92SAndroid Build Coastguard Worker
1402*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmTest_CpuRef_Test")
1403*89c4ff92SAndroid Build Coastguard Worker {
1404*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1405*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTest(backends);
1406*89c4ff92SAndroid Build Coastguard Worker }
1407*89c4ff92SAndroid Build Coastguard Worker
1408*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmTimeMajorTest_CpuRef_Test")
1409*89c4ff92SAndroid Build Coastguard Worker {
1410*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1411*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmTimeMajorTest(backends);
1412*89c4ff92SAndroid Build Coastguard Worker }
1413*89c4ff92SAndroid Build Coastguard Worker
1414*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionTest_CpuRef_Test")
1415*89c4ff92SAndroid Build Coastguard Worker {
1416*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1417*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionTest(backends);
1418*89c4ff92SAndroid Build Coastguard Worker }
1419*89c4ff92SAndroid Build Coastguard Worker
1420*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest_CpuRef_Test")
1421*89c4ff92SAndroid Build Coastguard Worker {
1422*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1423*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmWithCifgWithPeepholeNoProjectionTest(backends);
1424*89c4ff92SAndroid Build Coastguard Worker }
1425*89c4ff92SAndroid Build Coastguard Worker
1426*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNormTest_CpuRef_Test")
1427*89c4ff92SAndroid Build Coastguard Worker {
1428*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1429*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNormTest(backends);
1430*89c4ff92SAndroid Build Coastguard Worker }
1431*89c4ff92SAndroid Build Coastguard Worker
1432*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmInt8Test_CpuRef_Test")
1433*89c4ff92SAndroid Build Coastguard Worker {
1434*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1435*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmInt8Test(backends);
1436*89c4ff92SAndroid Build Coastguard Worker }
1437*89c4ff92SAndroid Build Coastguard Worker
1438*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmTimeInt8TimeMajorTest_CpuRef_Test")
1439*89c4ff92SAndroid Build Coastguard Worker {
1440*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1441*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmInt8TimeMajorTest(backends);
1442*89c4ff92SAndroid Build Coastguard Worker }
1443*89c4ff92SAndroid Build Coastguard Worker
1444*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionTest_CpuRef_Test")
1445*89c4ff92SAndroid Build Coastguard Worker {
1446*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1447*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionTest(backends);
1448*89c4ff92SAndroid Build Coastguard Worker }
1449*89c4ff92SAndroid Build Coastguard Worker
1450*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjectionTest_CpuRef_Test")
1451*89c4ff92SAndroid Build Coastguard Worker {
1452*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1453*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjectionTest(backends);
1454*89c4ff92SAndroid Build Coastguard Worker }
1455*89c4ff92SAndroid Build Coastguard Worker
1456*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionWithLayerNormTest_CpuRef_Test")
1457*89c4ff92SAndroid Build Coastguard Worker {
1458*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
1459*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmInt8NoCifgWithPeepholeWithProjectionWithLayerNormTest(backends);
1460*89c4ff92SAndroid Build Coastguard Worker }
1461*89c4ff92SAndroid Build Coastguard Worker
1462*89c4ff92SAndroid Build Coastguard Worker } //End of TEST_SUITE("UnidirectionalSequenceLstmTest_CpuRef")
1463*89c4ff92SAndroid Build Coastguard Worker
1464*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate