xref: /aosp_15_r20/external/armnn/delegate/test/LstmTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "LstmTestHelper.hpp"
7 
8 #include <armnn_delegate.hpp>
9 
10 #include <flatbuffers/flatbuffers.h>
11 #include <schema_generated.h>
12 #include <doctest/doctest.h>
13 
14 namespace armnnDelegate
15 {
16 
LstmTest(std::vector<armnn::BackendId> & backends)17 void LstmTest(std::vector<armnn::BackendId>& backends)
18 {
19     int32_t batchSize = 2;
20     int32_t inputSize = 2;
21     int32_t outputSize = 4;
22     // cellSize and outputSize have the same size when there is no projection.
23     int32_t numUnits = outputSize;
24 
25     std::vector<int32_t> inputShape {batchSize , inputSize};
26     std::vector<int32_t> cellStateInTensorInfo {batchSize , numUnits};
27     std::vector<int32_t> outputStateInTensorInfo {batchSize , outputSize};
28 
29     std::vector<int32_t> scratchBufferTensorInfo {batchSize, numUnits * 4};
30     std::vector<int32_t> cellStateOutTensorInfo {batchSize, numUnits};
31     std::vector<int32_t> outputStateOutTensorInfo {batchSize, outputSize};
32     std::vector<int32_t> outputTensorInfo {batchSize, outputSize};
33 
34     std::vector<int32_t> tensorInfo4 {numUnits};
35     std::vector<int32_t> tensorInfo8 {numUnits, 2};
36     std::vector<int32_t> tensorInfo16 {numUnits, 4};
37 
38     //tensorInfo8,
39     bool hasInputToInputWeights = true;
40     std::vector<float> inputToInputWeights {-0.45018822f, -0.02338299f, -0.0870589f,
41                                             -0.34550029f, 0.04266912f, -0.15680569f,
42                                             -0.34856534f, 0.43890524f};
43 
44     std::vector<float> inputToForgetWeights {0.09701663f, 0.20334584f, -0.50592935f,
45                                              -0.31343272f, -0.40032279f, 0.44781327f,
46                                              0.01387155f, -0.35593212f};
47 
48     std::vector<float> inputToCellWeights {-0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
49                                            -0.20583314f, 0.44344562f, 0.22077113f,
50                                            -0.29909778f};
51 
52     std::vector<float> inputToOutputWeights {-0.25065863f, -0.28290087f, 0.04613829f,
53                                              0.40525138f, 0.44272184f, 0.03897077f,
54                                              -0.1556896f, 0.19487578f};
55 
56     //tensorInfo16,
57     bool hasRecurrentToInputWeights = true;
58     std::vector<float> recurrentToInputWeights {-0.0063535f, -0.2042388f, 0.31454784f,
59                                                 -0.35746509f, 0.28902304f, 0.08183324f,
60                                                 -0.16555229f, 0.02286911f, -0.13566875f,
61                                                 0.03034258f, 0.48091322f, -0.12528998f,
62                                                 0.24077177f, -0.51332325f, -0.33502164f,
63                                                 0.10629296f};
64 
65     std::vector<float> recurrentToForgetWeights {-0.48684245f, -0.06655136f, 0.42224967f,
66                                                  0.2112639f, 0.27654213f, 0.20864892f,
67                                                  -0.07646349f, 0.45877004f, 0.00141793f,
68                                                  -0.14609534f, 0.36447752f, 0.09196436f,
69                                                  0.28053468f, 0.01560611f, -0.20127171f,
70                                                  -0.01140004f};
71 
72     std::vector<float> recurrentToCellWeights {-0.3407414f, 0.24443203f, -0.2078532f,
73                                                0.26320225f, 0.05695659f, -0.00123841f,
74                                                -0.4744786f, -0.35869038f, -0.06418842f,
75                                                -0.13502428f, -0.501764f, 0.22830659f,
76                                                -0.46367589f, 0.26016325f, -0.03894562f,
77                                                -0.16368064f};
78 
79     std::vector<float> recurrentToOutputWeights {0.43385774f, -0.17194885f, 0.2718237f,
80                                                  0.09215671f, 0.24107647f, -0.39835793f,
81                                                  0.18212086f, 0.01301402f, 0.48572797f,
82                                                  -0.50656658f, 0.20047462f, -0.20607421f,
83                                                  -0.51818722f, -0.15390486f, 0.0468148f,
84                                                  0.39922136f};
85     // tensorInfo4
86     bool hasCellToInputWeights = false;
87     std::vector<float> cellToInputWeights {};
88     bool hasCellToForgetWeights = false;
89     std::vector<float> cellToForgetWeights {};
90     bool hasCellToOutputWeights = false;
91     std::vector<float> cellToOutputWeights {};
92 
93     bool hasInputGateBias = true;
94     std::vector<float> inputGateBias {0., 0., 0., 0.};
95     std::vector<float> forgetGateBias {1., 1., 1., 1.};
96     std::vector<float> cellBias {0., 0., 0., 0.};
97     std::vector<float> outputGateBias {0., 0., 0., 0.};
98 
99     bool hasProjectionWeights = false;
100     std::vector<float> projectionWeights;
101     bool hasProjectionBias = false;
102     std::vector<float> projectionBias;
103 
104     bool hasInputLayerNormWeights = false;
105     std::vector<float> inputLayerNormWeights;
106     bool hasForgetLayerNormWeights = false;
107     std::vector<float> forgetLayerNormWeights;
108     bool hasCellLayerNormWeights = false;
109     std::vector<float> cellLayerNormWeights;
110     bool hasOutputLayerNormWeights = false;
111     std::vector<float> outputLayerNormWeights;
112 
113     std::vector<float> inputValues {2., 3., 3., 4.};
114     std::vector<float> expectedOutputValues {-0.02973187f, 0.1229473f,   0.20885126f, -0.15358765f,
115                                              -0.0185422f,   0.11281417f,  0.24466537f, -0.1826292f};
116 
117     tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
118     float clippingThresCell = 0.f;
119     float clippingThresProj = 0.f;
120 
121     LstmTestImpl<float>(backends,
122                         ::tflite::TensorType_FLOAT32,
123                         batchSize,
124                         inputSize,
125                         outputSize,
126                         numUnits,
127                         hasInputToInputWeights,
128                         inputToInputWeights,
129                         inputToForgetWeights,
130                         inputToCellWeights,
131                         inputToOutputWeights,
132                         hasRecurrentToInputWeights,
133                         recurrentToInputWeights,
134                         recurrentToForgetWeights,
135                         recurrentToCellWeights,
136                         recurrentToOutputWeights,
137                         hasCellToInputWeights,
138                         cellToInputWeights,
139                         hasCellToForgetWeights,
140                         cellToForgetWeights,
141                         hasCellToOutputWeights,
142                         cellToOutputWeights,
143                         hasInputGateBias,
144                         inputGateBias,
145                         forgetGateBias,
146                         cellBias,
147                         outputGateBias,
148                         hasProjectionWeights,
149                         projectionWeights,
150                         hasProjectionBias,
151                         projectionBias,
152                         hasInputLayerNormWeights,
153                         inputLayerNormWeights,
154                         hasForgetLayerNormWeights,
155                         forgetLayerNormWeights,
156                         hasCellLayerNormWeights,
157                         cellLayerNormWeights,
158                         hasOutputLayerNormWeights,
159                         outputLayerNormWeights,
160                         inputValues,
161                         expectedOutputValues,
162                         activationFunction,
163                         clippingThresCell,
164                         clippingThresProj);
165 }
166 
167 TEST_SUITE("LstmTest_CpuRefTests")
168 {
169 
170 TEST_CASE ("LstmTest_CpuRef_Test")
171 {
172     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
173     LstmTest(backends);
174 }
175 
176 } //End of TEST_SUITE("Convolution2dTest_CpuRef")
177 
178 TEST_SUITE("LstmTest_CpuAccTests")
179 {
180 
181 TEST_CASE ("LstmTest_CpuAcc_Test")
182 {
183     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuAcc};
184     LstmTest(backends);
185 }
186 
187 } //End of TEST_SUITE("Convolution2dTest_CpuAcc")
188 
189 } // namespace armnnDelegate