xref: /aosp_15_r20/external/armnn/delegate/test/LstmTestHelper.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "TestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreter.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/version.h>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker template <typename T>
CreateLstmTfLiteModel(tflite::TensorType tensorType,int32_t batchSize,int32_t inputSize,int32_t outputSize,int32_t numUnits,bool hasInputToInputWeights,const std::vector<T> & inputToInputWeights,const std::vector<T> & inputToForgetWeights,const std::vector<T> & inputToCellWeights,const std::vector<T> & inputToOutputWeights,bool hasRecurrentToInputWeights,const std::vector<T> & recurrentToInputWeights,const std::vector<T> & recurrentToForgetWeights,const std::vector<T> & recurrentToCellWeights,const std::vector<T> & recurrentToOutputWeights,bool hasCellToInputWeights,const std::vector<T> & cellToInputWeights,bool hasCellToForgetWeights,const std::vector<T> & cellToForgetWeights,bool hasCellToOutputWeights,const std::vector<T> & cellToOutputWeights,bool hasInputGateBias,const std::vector<T> & inputGateBias,const std::vector<T> & forgetGateBias,const std::vector<T> & cellBias,const std::vector<T> & outputGateBias,bool hasProjectionWeights,const std::vector<T> & projectionWeights,bool hasProjectionBias,const std::vector<T> & projectionBias,bool hasInputLayerNormWeights,const std::vector<T> & inputLayerNormWeights,bool hasForgetLayerNormWeights,const std::vector<T> & forgetLayerNormWeights,bool hasCellLayerNormWeights,const std::vector<T> & cellLayerNormWeights,bool hasOutputLayerNormWeights,const std::vector<T> & outputLayerNormWeights,tflite::ActivationFunctionType activationFunction,float clippingThresCell,float clippingThresProj,float quantScale=1.0f,int quantOffset=0,float outputQuantScale=2.0f,int outputQuantOffset=0)25*89c4ff92SAndroid Build Coastguard Worker std::vector<char> CreateLstmTfLiteModel(tflite::TensorType tensorType,
26*89c4ff92SAndroid Build Coastguard Worker                                         int32_t batchSize,
27*89c4ff92SAndroid Build Coastguard Worker                                         int32_t inputSize,
28*89c4ff92SAndroid Build Coastguard Worker                                         int32_t outputSize,
29*89c4ff92SAndroid Build Coastguard Worker                                         int32_t numUnits,
30*89c4ff92SAndroid Build Coastguard Worker                                         bool hasInputToInputWeights,
31*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToInputWeights,
32*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToForgetWeights,
33*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToCellWeights,
34*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToOutputWeights,
35*89c4ff92SAndroid Build Coastguard Worker                                         bool hasRecurrentToInputWeights,
36*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToInputWeights,
37*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToForgetWeights,
38*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToCellWeights,
39*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToOutputWeights,
40*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellToInputWeights,
41*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellToInputWeights,
42*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellToForgetWeights,
43*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellToForgetWeights,
44*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellToOutputWeights,
45*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellToOutputWeights,
46*89c4ff92SAndroid Build Coastguard Worker                                         bool hasInputGateBias,
47*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputGateBias,
48*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& forgetGateBias,
49*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellBias,
50*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& outputGateBias,
51*89c4ff92SAndroid Build Coastguard Worker                                         bool hasProjectionWeights,
52*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& projectionWeights,
53*89c4ff92SAndroid Build Coastguard Worker                                         bool hasProjectionBias,
54*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& projectionBias,
55*89c4ff92SAndroid Build Coastguard Worker                                         bool hasInputLayerNormWeights,
56*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputLayerNormWeights,
57*89c4ff92SAndroid Build Coastguard Worker                                         bool hasForgetLayerNormWeights,
58*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& forgetLayerNormWeights,
59*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellLayerNormWeights,
60*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellLayerNormWeights,
61*89c4ff92SAndroid Build Coastguard Worker                                         bool hasOutputLayerNormWeights,
62*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& outputLayerNormWeights,
63*89c4ff92SAndroid Build Coastguard Worker                                         tflite::ActivationFunctionType activationFunction,
64*89c4ff92SAndroid Build Coastguard Worker                                         float clippingThresCell,
65*89c4ff92SAndroid Build Coastguard Worker                                         float clippingThresProj,
66*89c4ff92SAndroid Build Coastguard Worker                                         float quantScale = 1.0f,
67*89c4ff92SAndroid Build Coastguard Worker                                         int quantOffset  = 0,
68*89c4ff92SAndroid Build Coastguard Worker                                         float outputQuantScale = 2.0f,
69*89c4ff92SAndroid Build Coastguard Worker                                         int outputQuantOffset  = 0)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     std::vector <int32_t> tensorInfo0 {};
73*89c4ff92SAndroid Build Coastguard Worker     std::vector <int32_t> tensorInfo4 {numUnits};
74*89c4ff92SAndroid Build Coastguard Worker     std::vector <int32_t> tensorInfo8 {numUnits, static_cast<int32_t>(2)};
75*89c4ff92SAndroid Build Coastguard Worker     std::vector <int32_t> tensorInfo16 {numUnits, static_cast<int32_t>(4)};
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape {batchSize , inputSize};
78*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape {batchSize , outputSize};
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
81*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> operatorInputs;
84*89c4ff92SAndroid Build Coastguard Worker     using namespace tflite;
85*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::FlatBufferBuilder flatBufferBuilder;
86*89c4ff92SAndroid Build Coastguard Worker     std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<flatbuffers::Offset<Tensor>> tensors;
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     auto quantizationParameters =
90*89c4ff92SAndroid Build Coastguard Worker         CreateQuantizationParameters(flatBufferBuilder,
91*89c4ff92SAndroid Build Coastguard Worker                                      0,
92*89c4ff92SAndroid Build Coastguard Worker                                      0,
93*89c4ff92SAndroid Build Coastguard Worker                                      flatBufferBuilder.CreateVector<float>({ quantScale }),
94*89c4ff92SAndroid Build Coastguard Worker                                      flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     auto outputQuantizationParameters =
97*89c4ff92SAndroid Build Coastguard Worker         CreateQuantizationParameters(flatBufferBuilder,
98*89c4ff92SAndroid Build Coastguard Worker                                      0,
99*89c4ff92SAndroid Build Coastguard Worker                                      0,
100*89c4ff92SAndroid Build Coastguard Worker                                      flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
101*89c4ff92SAndroid Build Coastguard Worker                                      flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
104*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
105*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
106*89c4ff92SAndroid Build Coastguard Worker                                                                            inputShape.size()),
107*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
108*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
109*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("input_0"),
110*89c4ff92SAndroid Build Coastguard Worker                                    quantizationParameters));
111*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker     if (hasInputToInputWeights)
114*89c4ff92SAndroid Build Coastguard Worker     {
115*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
116*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
117*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToInputWeights.data()),
118*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * inputToInputWeights.size())));
119*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
120*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
121*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo8.size()),
122*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
123*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
124*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("inputToInputWeights"),
125*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
126*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
127*89c4ff92SAndroid Build Coastguard Worker     }
128*89c4ff92SAndroid Build Coastguard Worker     else
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
134*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
135*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToForgetWeights.data()),
136*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * inputToForgetWeights.size())));
137*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
138*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
139*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo8.size()),
140*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
141*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
142*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("inputToForgetWeights"),
143*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
144*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
147*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
148*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToCellWeights.data()),
149*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * inputToCellWeights.size())));
150*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
151*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
152*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo8.size()),
153*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
154*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
155*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("inputToCellWeights"),
156*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
157*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
160*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
161*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToOutputWeights.data()),
162*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * inputToOutputWeights.size())));
163*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
164*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
165*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo8.size()),
166*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
167*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
168*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("inputToOutputWeights"),
169*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
170*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker     if (hasRecurrentToInputWeights)
173*89c4ff92SAndroid Build Coastguard Worker     {
174*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(CreateBuffer(
175*89c4ff92SAndroid Build Coastguard Worker             flatBufferBuilder,
176*89c4ff92SAndroid Build Coastguard Worker             flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
177*89c4ff92SAndroid Build Coastguard Worker                                            sizeof(T) * recurrentToInputWeights.size())));
178*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
179*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
180*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo16.size()),
181*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
182*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
183*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("recurrentToInputWeights"),
184*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
185*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
186*89c4ff92SAndroid Build Coastguard Worker     }
187*89c4ff92SAndroid Build Coastguard Worker     else
188*89c4ff92SAndroid Build Coastguard Worker     {
189*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
190*89c4ff92SAndroid Build Coastguard Worker     }
191*89c4ff92SAndroid Build Coastguard Worker 
192*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
193*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
194*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToForgetWeights.data()),
195*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * recurrentToForgetWeights.size())));
196*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
197*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
198*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo16.size()),
199*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
200*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
201*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("recurrentToForgetWeights"),
202*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
203*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
206*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
207*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToCellWeights.data()),
208*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * recurrentToCellWeights.size())));
209*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
210*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
211*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo16.size()),
212*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
213*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
214*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("recurrentToCellWeights"),
215*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
216*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
219*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
220*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToOutputWeights.data()),
221*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * recurrentToOutputWeights.size())));
222*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
223*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
224*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo16.size()),
225*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
226*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1 ,
227*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("recurrentToOutputWeights"),
228*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
229*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     if (hasCellToInputWeights)
232*89c4ff92SAndroid Build Coastguard Worker     {
233*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
234*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
235*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
236*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellToInputWeights.size())));
237*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
238*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
239*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
240*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
241*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
242*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellToInputWeights"),
243*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
244*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
245*89c4ff92SAndroid Build Coastguard Worker     }
246*89c4ff92SAndroid Build Coastguard Worker     else
247*89c4ff92SAndroid Build Coastguard Worker     {
248*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
249*89c4ff92SAndroid Build Coastguard Worker     }
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker     if (hasCellToForgetWeights)
252*89c4ff92SAndroid Build Coastguard Worker     {
253*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
254*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
255*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
256*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellToForgetWeights.size())));
257*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
258*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
259*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
260*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
261*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
262*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellToForgetWeights"),
263*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
264*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker     else
267*89c4ff92SAndroid Build Coastguard Worker     {
268*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
269*89c4ff92SAndroid Build Coastguard Worker     }
270*89c4ff92SAndroid Build Coastguard Worker 
271*89c4ff92SAndroid Build Coastguard Worker     if (hasCellToOutputWeights)
272*89c4ff92SAndroid Build Coastguard Worker     {
273*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
274*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
275*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
276*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellToOutputWeights.size())));
277*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
278*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
279*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
280*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
281*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
282*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellToOutputWeights"),
283*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
284*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
285*89c4ff92SAndroid Build Coastguard Worker     }
286*89c4ff92SAndroid Build Coastguard Worker     else
287*89c4ff92SAndroid Build Coastguard Worker     {
288*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
289*89c4ff92SAndroid Build Coastguard Worker     }
290*89c4ff92SAndroid Build Coastguard Worker 
291*89c4ff92SAndroid Build Coastguard Worker     if (hasInputGateBias)
292*89c4ff92SAndroid Build Coastguard Worker     {
293*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
294*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
295*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
296*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * inputGateBias.size())));
297*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
298*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
299*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
300*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
301*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
302*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("inputGateBias"),
303*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
304*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
305*89c4ff92SAndroid Build Coastguard Worker     }
306*89c4ff92SAndroid Build Coastguard Worker     else
307*89c4ff92SAndroid Build Coastguard Worker     {
308*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
309*89c4ff92SAndroid Build Coastguard Worker     }
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
312*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
313*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(forgetGateBias.data()),
314*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * forgetGateBias.size())));
315*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
316*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
317*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo4.size()),
318*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
319*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
320*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("forgetGateBias"),
321*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
322*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
325*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
326*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(cellBias.data()),
327*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * cellBias.size())));
328*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
329*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
330*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo4.size()),
331*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
332*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
333*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("cellBias"),
334*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
335*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
338*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
339*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(outputGateBias.data()),
340*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * outputGateBias.size())));
341*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
342*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
343*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfo4.size()),
344*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
345*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
346*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("outputGateBias"),
347*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
348*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker     if (hasProjectionWeights)
351*89c4ff92SAndroid Build Coastguard Worker     {
352*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
353*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
354*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(projectionWeights.data()),
355*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * projectionWeights.size())));
356*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
357*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
358*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
359*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
360*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
361*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("outputGateBias"),
362*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
363*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
364*89c4ff92SAndroid Build Coastguard Worker     }
365*89c4ff92SAndroid Build Coastguard Worker     else
366*89c4ff92SAndroid Build Coastguard Worker     {
367*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
368*89c4ff92SAndroid Build Coastguard Worker     }
369*89c4ff92SAndroid Build Coastguard Worker 
370*89c4ff92SAndroid Build Coastguard Worker     if (hasProjectionBias)
371*89c4ff92SAndroid Build Coastguard Worker     {
372*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
373*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
374*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(projectionBias.data()),
375*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * projectionBias.size())));
376*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
377*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
378*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
379*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
380*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
381*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("projectionBias"),
382*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
383*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
384*89c4ff92SAndroid Build Coastguard Worker     }
385*89c4ff92SAndroid Build Coastguard Worker     else
386*89c4ff92SAndroid Build Coastguard Worker     {
387*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
388*89c4ff92SAndroid Build Coastguard Worker     }
389*89c4ff92SAndroid Build Coastguard Worker 
390*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
391*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
392*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
393*89c4ff92SAndroid Build Coastguard Worker                                                                            outputStateInDimensions.size()),
394*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
395*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
396*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("outputStateInInfo"),
397*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters,
398*89c4ff92SAndroid Build Coastguard Worker                                    true));
399*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
400*89c4ff92SAndroid Build Coastguard Worker 
401*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
402*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
403*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
404*89c4ff92SAndroid Build Coastguard Worker                                                                            cellStateInDimensions.size()),
405*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
406*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
407*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("cellStateInInfo"),
408*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters,
409*89c4ff92SAndroid Build Coastguard Worker                                    true));
410*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(buffers.size() - 1);
411*89c4ff92SAndroid Build Coastguard Worker 
412*89c4ff92SAndroid Build Coastguard Worker     if (hasInputLayerNormWeights)
413*89c4ff92SAndroid Build Coastguard Worker     {
414*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
415*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
416*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
417*89c4ff92SAndroid Build Coastguard Worker                                               reinterpret_cast<const uint8_t *>(inputLayerNormWeights.data()),
418*89c4ff92SAndroid Build Coastguard Worker                                               sizeof(T) * inputLayerNormWeights.size())));
419*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
420*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
421*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
422*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
423*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
424*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("inputLayerNormWeights"),
425*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
426*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
427*89c4ff92SAndroid Build Coastguard Worker     }
428*89c4ff92SAndroid Build Coastguard Worker     else
429*89c4ff92SAndroid Build Coastguard Worker     {
430*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
431*89c4ff92SAndroid Build Coastguard Worker     }
432*89c4ff92SAndroid Build Coastguard Worker 
433*89c4ff92SAndroid Build Coastguard Worker     if (hasForgetLayerNormWeights)
434*89c4ff92SAndroid Build Coastguard Worker     {
435*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
436*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
437*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
438*89c4ff92SAndroid Build Coastguard Worker                                               reinterpret_cast<const uint8_t *>(forgetLayerNormWeights.data()),
439*89c4ff92SAndroid Build Coastguard Worker                                               sizeof(T) * forgetLayerNormWeights.size())));
440*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
441*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
442*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
443*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
444*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
445*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("forgetLayerNormWeights"),
446*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
447*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
448*89c4ff92SAndroid Build Coastguard Worker     }
449*89c4ff92SAndroid Build Coastguard Worker     else
450*89c4ff92SAndroid Build Coastguard Worker     {
451*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
452*89c4ff92SAndroid Build Coastguard Worker     }
453*89c4ff92SAndroid Build Coastguard Worker 
454*89c4ff92SAndroid Build Coastguard Worker     if (hasCellLayerNormWeights)
455*89c4ff92SAndroid Build Coastguard Worker     {
456*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
457*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
458*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(cellLayerNormWeights.data()),
459*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellLayerNormWeights.size())));
460*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
461*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
462*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
463*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
464*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
465*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellLayerNormWeights"),
466*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
467*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
468*89c4ff92SAndroid Build Coastguard Worker     }
469*89c4ff92SAndroid Build Coastguard Worker     else
470*89c4ff92SAndroid Build Coastguard Worker     {
471*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
472*89c4ff92SAndroid Build Coastguard Worker     }
473*89c4ff92SAndroid Build Coastguard Worker 
474*89c4ff92SAndroid Build Coastguard Worker     if (hasOutputLayerNormWeights)
475*89c4ff92SAndroid Build Coastguard Worker     {
476*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
477*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
478*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
479*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t *>(outputLayerNormWeights.data()),
480*89c4ff92SAndroid Build Coastguard Worker                              sizeof(T) * outputLayerNormWeights.size())));
481*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
482*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
483*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfo4.size()),
484*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
485*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
486*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("outputLayerNormWeights"),
487*89c4ff92SAndroid Build Coastguard Worker                                        outputQuantizationParameters));
488*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(buffers.size() - 1);
489*89c4ff92SAndroid Build Coastguard Worker     }
490*89c4ff92SAndroid Build Coastguard Worker     else
491*89c4ff92SAndroid Build Coastguard Worker     {
492*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
493*89c4ff92SAndroid Build Coastguard Worker     }
494*89c4ff92SAndroid Build Coastguard Worker     int outputBufferId = buffers.size();
495*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
496*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
497*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
498*89c4ff92SAndroid Build Coastguard Worker                                                                            outputShape.size()),
499*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
500*89c4ff92SAndroid Build Coastguard Worker                                    outputBufferId,
501*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("output"),
502*89c4ff92SAndroid Build Coastguard Worker                                    outputQuantizationParameters));
503*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> operatorOutputs;
504*89c4ff92SAndroid Build Coastguard Worker     operatorOutputs.push_back(buffers.size() - 1);
505*89c4ff92SAndroid Build Coastguard Worker 
506*89c4ff92SAndroid Build Coastguard Worker     // create operator
507*89c4ff92SAndroid Build Coastguard Worker     tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_LSTMOptions;
508*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<void> operatorBuiltinOptions =
509*89c4ff92SAndroid Build Coastguard Worker         CreateLSTMOptions(flatBufferBuilder,
510*89c4ff92SAndroid Build Coastguard Worker                           activationFunction,
511*89c4ff92SAndroid Build Coastguard Worker                           clippingThresCell,
512*89c4ff92SAndroid Build Coastguard Worker                           clippingThresProj).Union();
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset <Operator> lstmOperator =
515*89c4ff92SAndroid Build Coastguard Worker         CreateOperator(flatBufferBuilder,
516*89c4ff92SAndroid Build Coastguard Worker                        0,
517*89c4ff92SAndroid Build Coastguard Worker                        flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
518*89c4ff92SAndroid Build Coastguard Worker                        flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
519*89c4ff92SAndroid Build Coastguard Worker                        operatorBuiltinOptionsType, operatorBuiltinOptions);
520*89c4ff92SAndroid Build Coastguard Worker 
521*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset <SubGraph> subgraph =
522*89c4ff92SAndroid Build Coastguard Worker         CreateSubGraph(flatBufferBuilder,
523*89c4ff92SAndroid Build Coastguard Worker                        flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
524*89c4ff92SAndroid Build Coastguard Worker                        flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
525*89c4ff92SAndroid Build Coastguard Worker                        flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
526*89c4ff92SAndroid Build Coastguard Worker                        flatBufferBuilder.CreateVector(&lstmOperator, 1));
527*89c4ff92SAndroid Build Coastguard Worker 
528*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset <flatbuffers::String> modelDescription =
529*89c4ff92SAndroid Build Coastguard Worker         flatBufferBuilder.CreateString("ArmnnDelegate: LSTM Operator Model");
530*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
531*89c4ff92SAndroid Build Coastguard Worker                                                                          tflite::BuiltinOperator_LSTM);
532*89c4ff92SAndroid Build Coastguard Worker 
533*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset <Model> flatbufferModel =
534*89c4ff92SAndroid Build Coastguard Worker         CreateModel(flatBufferBuilder,
535*89c4ff92SAndroid Build Coastguard Worker                     TFLITE_SCHEMA_VERSION,
536*89c4ff92SAndroid Build Coastguard Worker                     flatBufferBuilder.CreateVector(&operatorCode, 1),
537*89c4ff92SAndroid Build Coastguard Worker                     flatBufferBuilder.CreateVector(&subgraph, 1),
538*89c4ff92SAndroid Build Coastguard Worker                     modelDescription,
539*89c4ff92SAndroid Build Coastguard Worker                     flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
540*89c4ff92SAndroid Build Coastguard Worker 
541*89c4ff92SAndroid Build Coastguard Worker     flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
542*89c4ff92SAndroid Build Coastguard Worker 
543*89c4ff92SAndroid Build Coastguard Worker     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
544*89c4ff92SAndroid Build Coastguard Worker                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
545*89c4ff92SAndroid Build Coastguard Worker }
546*89c4ff92SAndroid Build Coastguard Worker 
547*89c4ff92SAndroid Build Coastguard Worker template <typename T>
LstmTestImpl(std::vector<armnn::BackendId> & backends,tflite::TensorType tensorType,int32_t batchSize,int32_t inputSize,int32_t outputSize,int32_t numUnits,bool hasInputToInputWeights,const std::vector<T> & inputToInputWeights,const std::vector<T> & inputToForgetWeights,const std::vector<T> & inputToCellWeights,const std::vector<T> & inputToOutputWeights,bool hasRecurrentToInputWeights,const std::vector<T> & recurrentToInputWeights,const std::vector<T> & recurrentToForgetWeights,const std::vector<T> & recurrentToCellWeights,const std::vector<T> & recurrentToOutputWeights,bool hasCellToInputWeights,const std::vector<T> & cellToInputWeights,bool hasCellToForgetWeights,const std::vector<T> & cellToForgetWeights,bool hasCellToOutputWeights,const std::vector<T> & cellToOutputWeights,bool hasInputGateBias,const std::vector<T> & inputGateBias,const std::vector<T> & forgetGateBias,const std::vector<T> & cellBias,const std::vector<T> & outputGateBias,bool hasProjectionWeights,const std::vector<T> & projectionWeights,bool hasProjectionBias,const std::vector<T> & projectionBias,bool hasInputLayerNormWeights,const std::vector<T> & inputLayerNormWeights,bool hasForgetLayerNormWeights,const std::vector<T> & forgetLayerNormWeights,bool hasCellLayerNormWeights,const std::vector<T> & cellLayerNormWeights,bool hasOutputLayerNormWeights,const std::vector<T> & outputLayerNormWeights,std::vector<T> & inputValues,std::vector<T> & expectedOutputValues,tflite::ActivationFunctionType activationFunction,float clippingThresCell,float clippingThresProj)548*89c4ff92SAndroid Build Coastguard Worker void LstmTestImpl(std::vector<armnn::BackendId>& backends,
549*89c4ff92SAndroid Build Coastguard Worker                   tflite::TensorType tensorType,
550*89c4ff92SAndroid Build Coastguard Worker                   int32_t batchSize,
551*89c4ff92SAndroid Build Coastguard Worker                   int32_t inputSize,
552*89c4ff92SAndroid Build Coastguard Worker                   int32_t outputSize,
553*89c4ff92SAndroid Build Coastguard Worker                   int32_t numUnits,
554*89c4ff92SAndroid Build Coastguard Worker                   bool hasInputToInputWeights,
555*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& inputToInputWeights,
556*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& inputToForgetWeights,
557*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& inputToCellWeights,
558*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& inputToOutputWeights,
559*89c4ff92SAndroid Build Coastguard Worker                   bool hasRecurrentToInputWeights,
560*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& recurrentToInputWeights,
561*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& recurrentToForgetWeights,
562*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& recurrentToCellWeights,
563*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& recurrentToOutputWeights,
564*89c4ff92SAndroid Build Coastguard Worker                   bool hasCellToInputWeights,
565*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& cellToInputWeights,
566*89c4ff92SAndroid Build Coastguard Worker                   bool hasCellToForgetWeights,
567*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& cellToForgetWeights,
568*89c4ff92SAndroid Build Coastguard Worker                   bool hasCellToOutputWeights,
569*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& cellToOutputWeights,
570*89c4ff92SAndroid Build Coastguard Worker                   bool hasInputGateBias,
571*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& inputGateBias,
572*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& forgetGateBias,
573*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& cellBias,
574*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& outputGateBias,
575*89c4ff92SAndroid Build Coastguard Worker                   bool hasProjectionWeights,
576*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& projectionWeights,
577*89c4ff92SAndroid Build Coastguard Worker                   bool hasProjectionBias,
578*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& projectionBias,
579*89c4ff92SAndroid Build Coastguard Worker                   bool hasInputLayerNormWeights,
580*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& inputLayerNormWeights,
581*89c4ff92SAndroid Build Coastguard Worker                   bool hasForgetLayerNormWeights,
582*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& forgetLayerNormWeights,
583*89c4ff92SAndroid Build Coastguard Worker                   bool hasCellLayerNormWeights,
584*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& cellLayerNormWeights,
585*89c4ff92SAndroid Build Coastguard Worker                   bool hasOutputLayerNormWeights,
586*89c4ff92SAndroid Build Coastguard Worker                   const std::vector<T>& outputLayerNormWeights,
587*89c4ff92SAndroid Build Coastguard Worker                   std::vector<T>& inputValues,
588*89c4ff92SAndroid Build Coastguard Worker                   std::vector<T>& expectedOutputValues,
589*89c4ff92SAndroid Build Coastguard Worker                   tflite::ActivationFunctionType activationFunction,
590*89c4ff92SAndroid Build Coastguard Worker                   float clippingThresCell,
591*89c4ff92SAndroid Build Coastguard Worker                   float clippingThresProj)
592*89c4ff92SAndroid Build Coastguard Worker {
593*89c4ff92SAndroid Build Coastguard Worker     using namespace delegateTestInterpreter;
594*89c4ff92SAndroid Build Coastguard Worker 
595*89c4ff92SAndroid Build Coastguard Worker     std::vector<char> modelBuffer = CreateLstmTfLiteModel(tensorType,
596*89c4ff92SAndroid Build Coastguard Worker                                                           batchSize,
597*89c4ff92SAndroid Build Coastguard Worker                                                           inputSize,
598*89c4ff92SAndroid Build Coastguard Worker                                                           outputSize,
599*89c4ff92SAndroid Build Coastguard Worker                                                           numUnits,
600*89c4ff92SAndroid Build Coastguard Worker                                                           hasInputToInputWeights,
601*89c4ff92SAndroid Build Coastguard Worker                                                           inputToInputWeights,
602*89c4ff92SAndroid Build Coastguard Worker                                                           inputToForgetWeights,
603*89c4ff92SAndroid Build Coastguard Worker                                                           inputToCellWeights,
604*89c4ff92SAndroid Build Coastguard Worker                                                           inputToOutputWeights,
605*89c4ff92SAndroid Build Coastguard Worker                                                           hasRecurrentToInputWeights,
606*89c4ff92SAndroid Build Coastguard Worker                                                           recurrentToInputWeights,
607*89c4ff92SAndroid Build Coastguard Worker                                                           recurrentToForgetWeights,
608*89c4ff92SAndroid Build Coastguard Worker                                                           recurrentToCellWeights,
609*89c4ff92SAndroid Build Coastguard Worker                                                           recurrentToOutputWeights,
610*89c4ff92SAndroid Build Coastguard Worker                                                           hasCellToInputWeights,
611*89c4ff92SAndroid Build Coastguard Worker                                                           cellToInputWeights,
612*89c4ff92SAndroid Build Coastguard Worker                                                           hasCellToForgetWeights,
613*89c4ff92SAndroid Build Coastguard Worker                                                           cellToForgetWeights,
614*89c4ff92SAndroid Build Coastguard Worker                                                           hasCellToOutputWeights,
615*89c4ff92SAndroid Build Coastguard Worker                                                           cellToOutputWeights,
616*89c4ff92SAndroid Build Coastguard Worker                                                           hasInputGateBias,
617*89c4ff92SAndroid Build Coastguard Worker                                                           inputGateBias,
618*89c4ff92SAndroid Build Coastguard Worker                                                           forgetGateBias,
619*89c4ff92SAndroid Build Coastguard Worker                                                           cellBias,
620*89c4ff92SAndroid Build Coastguard Worker                                                           outputGateBias,
621*89c4ff92SAndroid Build Coastguard Worker                                                           hasProjectionWeights,
622*89c4ff92SAndroid Build Coastguard Worker                                                           projectionWeights,
623*89c4ff92SAndroid Build Coastguard Worker                                                           hasProjectionBias,
624*89c4ff92SAndroid Build Coastguard Worker                                                           projectionBias,
625*89c4ff92SAndroid Build Coastguard Worker                                                           hasInputLayerNormWeights,
626*89c4ff92SAndroid Build Coastguard Worker                                                           inputLayerNormWeights,
627*89c4ff92SAndroid Build Coastguard Worker                                                           hasForgetLayerNormWeights,
628*89c4ff92SAndroid Build Coastguard Worker                                                           forgetLayerNormWeights,
629*89c4ff92SAndroid Build Coastguard Worker                                                           hasCellLayerNormWeights,
630*89c4ff92SAndroid Build Coastguard Worker                                                           cellLayerNormWeights,
631*89c4ff92SAndroid Build Coastguard Worker                                                           hasOutputLayerNormWeights,
632*89c4ff92SAndroid Build Coastguard Worker                                                           outputLayerNormWeights,
633*89c4ff92SAndroid Build Coastguard Worker                                                           activationFunction,
634*89c4ff92SAndroid Build Coastguard Worker                                                           clippingThresCell,
635*89c4ff92SAndroid Build Coastguard Worker                                                           clippingThresProj);
636*89c4ff92SAndroid Build Coastguard Worker 
637*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape {batchSize , outputSize};
638*89c4ff92SAndroid Build Coastguard Worker 
639*89c4ff92SAndroid Build Coastguard Worker     // Setup interpreter with just TFLite Runtime.
640*89c4ff92SAndroid Build Coastguard Worker     auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
641*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
642*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
643*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
644*89c4ff92SAndroid Build Coastguard Worker     std::vector<T>       tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
645*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);
646*89c4ff92SAndroid Build Coastguard Worker 
647*89c4ff92SAndroid Build Coastguard Worker     // Setup interpreter with Arm NN Delegate applied.
648*89c4ff92SAndroid Build Coastguard Worker     auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
649*89c4ff92SAndroid Build Coastguard Worker     CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
650*89c4ff92SAndroid Build Coastguard Worker     CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
651*89c4ff92SAndroid Build Coastguard Worker     CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
652*89c4ff92SAndroid Build Coastguard Worker     std::vector<T>       armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
653*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);
654*89c4ff92SAndroid Build Coastguard Worker 
655*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
656*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);
657*89c4ff92SAndroid Build Coastguard Worker 
658*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter.Cleanup();
659*89c4ff92SAndroid Build Coastguard Worker     armnnInterpreter.Cleanup();
660*89c4ff92SAndroid Build Coastguard Worker }
661*89c4ff92SAndroid Build Coastguard Worker 
662*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace