xref: /aosp_15_r20/external/armnn/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "TestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreter.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/version.h>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
22*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker #include <initializer_list>
28*89c4ff92SAndroid Build Coastguard Worker #include <iterator>
29*89c4ff92SAndroid Build Coastguard Worker #include <vector>
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker namespace
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker template<typename T>
CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,int32_t batchSize,int32_t timeSize,int32_t inputSize,int32_t outputSize,int32_t numUnits,bool hasInputToInputWeights,const std::vector<T> & inputToInputWeights,const std::vector<T> & inputToForgetWeights,const std::vector<T> & inputToCellWeights,const std::vector<T> & inputToOutputWeights,bool hasRecurrentToInputWeights,const std::vector<T> & recurrentToInputWeights,const std::vector<T> & recurrentToForgetWeights,const std::vector<T> & recurrentToCellWeights,const std::vector<T> & recurrentToOutputWeights,bool hasCellToInputWeights,const std::vector<T> & cellToInputWeights,bool hasCellToForgetWeights,const std::vector<T> & cellToForgetWeights,bool hasCellToOutputWeights,const std::vector<T> & cellToOutputWeights,bool hasInputGateBias,const std::vector<float> & inputGateBias,const std::vector<float> & forgetGateBias,const std::vector<float> & cellBias,const std::vector<float> & outputGateBias,bool hasProjectionWeights,const std::vector<T> & projectionWeights,bool hasProjectionBias,const std::vector<float> & projectionBias,bool hasInputLayerNormWeights,const std::vector<float> & inputLayerNormWeights,bool hasForgetLayerNormWeights,const std::vector<float> & forgetLayerNormWeights,bool hasCellLayerNormWeights,const std::vector<float> & cellLayerNormWeights,bool hasOutputLayerNormWeights,const std::vector<float> & outputLayerNormWeights,tflite::ActivationFunctionType activationFunction,float clippingThresCell,float clippingThresProj,bool isTimeMajor,float quantScale,int quantOffset=0)35*89c4ff92SAndroid Build Coastguard Worker std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,
36*89c4ff92SAndroid Build Coastguard Worker                                                               int32_t batchSize,
37*89c4ff92SAndroid Build Coastguard Worker                                                               int32_t timeSize,
38*89c4ff92SAndroid Build Coastguard Worker                                                               int32_t inputSize,
39*89c4ff92SAndroid Build Coastguard Worker                                                               int32_t outputSize,
40*89c4ff92SAndroid Build Coastguard Worker                                                               int32_t numUnits,
41*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasInputToInputWeights,
42*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& inputToInputWeights,
43*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& inputToForgetWeights,
44*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& inputToCellWeights,
45*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& inputToOutputWeights,
46*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasRecurrentToInputWeights,
47*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& recurrentToInputWeights,
48*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& recurrentToForgetWeights,
49*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& recurrentToCellWeights,
50*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& recurrentToOutputWeights,
51*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasCellToInputWeights,
52*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& cellToInputWeights,
53*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasCellToForgetWeights,
54*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& cellToForgetWeights,
55*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasCellToOutputWeights,
56*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& cellToOutputWeights,
57*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasInputGateBias,
58*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& inputGateBias,
59*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& forgetGateBias,
60*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& cellBias,
61*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& outputGateBias,
62*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasProjectionWeights,
63*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<T>& projectionWeights,
64*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasProjectionBias,
65*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& projectionBias,
66*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasInputLayerNormWeights,
67*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& inputLayerNormWeights,
68*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasForgetLayerNormWeights,
69*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& forgetLayerNormWeights,
70*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasCellLayerNormWeights,
71*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& cellLayerNormWeights,
72*89c4ff92SAndroid Build Coastguard Worker                                                               bool hasOutputLayerNormWeights,
73*89c4ff92SAndroid Build Coastguard Worker                                                               const std::vector<float>& outputLayerNormWeights,
74*89c4ff92SAndroid Build Coastguard Worker                                                               tflite::ActivationFunctionType activationFunction,
75*89c4ff92SAndroid Build Coastguard Worker                                                               float clippingThresCell,
76*89c4ff92SAndroid Build Coastguard Worker                                                               float clippingThresProj,
77*89c4ff92SAndroid Build Coastguard Worker                                                               bool isTimeMajor,
78*89c4ff92SAndroid Build Coastguard Worker                                                               float quantScale,
79*89c4ff92SAndroid Build Coastguard Worker                                                               int quantOffset = 0)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorInfo0{};
83*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorInfoNumUnits{numUnits};
84*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorInfoInputSize{numUnits, inputSize};
85*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorInfoOutputSize{numUnits, outputSize};
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape;
88*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape;
89*89c4ff92SAndroid Build Coastguard Worker     if (isTimeMajor)
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         inputShape  = {timeSize, batchSize, inputSize};
92*89c4ff92SAndroid Build Coastguard Worker         outputShape = {timeSize, batchSize, outputSize};
93*89c4ff92SAndroid Build Coastguard Worker     }
94*89c4ff92SAndroid Build Coastguard Worker     else
95*89c4ff92SAndroid Build Coastguard Worker     {
96*89c4ff92SAndroid Build Coastguard Worker         inputShape  = {batchSize, timeSize, inputSize};
97*89c4ff92SAndroid Build Coastguard Worker         outputShape = {batchSize, timeSize, outputSize};
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
100*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
101*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> projectionWeightDimensions{outputSize, numUnits};
102*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> projectionBiasDimensions{outputSize};
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> operatorInputs;
105*89c4ff92SAndroid Build Coastguard Worker     using namespace tflite;
106*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::FlatBufferBuilder                   flatBufferBuilder;
107*89c4ff92SAndroid Build Coastguard Worker     std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
108*89c4ff92SAndroid Build Coastguard Worker     std::vector<flatbuffers::Offset<Tensor>>         tensors;
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     auto quantizationParameters =
111*89c4ff92SAndroid Build Coastguard Worker              CreateQuantizationParameters(flatBufferBuilder,
112*89c4ff92SAndroid Build Coastguard Worker                                           0,
113*89c4ff92SAndroid Build Coastguard Worker                                           0,
114*89c4ff92SAndroid Build Coastguard Worker                                           flatBufferBuilder.CreateVector<float>({1.0f}),
115*89c4ff92SAndroid Build Coastguard Worker                                           flatBufferBuilder.CreateVector<int64_t>({0}));
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker     auto weightQuantizationParameters =
118*89c4ff92SAndroid Build Coastguard Worker              CreateQuantizationParameters(flatBufferBuilder,
119*89c4ff92SAndroid Build Coastguard Worker                                           0,
120*89c4ff92SAndroid Build Coastguard Worker                                           0,
121*89c4ff92SAndroid Build Coastguard Worker                                           flatBufferBuilder.CreateVector<float>({quantScale}),
122*89c4ff92SAndroid Build Coastguard Worker                                           flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
125*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
126*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
127*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
128*89c4ff92SAndroid Build Coastguard Worker                                                                            inputShape.size()),
129*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
130*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
131*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("input_0")));
132*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     if (hasInputToInputWeights)
135*89c4ff92SAndroid Build Coastguard Worker     {
136*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
137*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
138*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
139*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
140*89c4ff92SAndroid Build Coastguard Worker                              sizeof(T) * inputToInputWeights.size())));
141*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
142*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
143*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoInputSize.size()),
144*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
145*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
146*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("inputToInputWeights"),
147*89c4ff92SAndroid Build Coastguard Worker                                        weightQuantizationParameters));
148*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
149*89c4ff92SAndroid Build Coastguard Worker     }
150*89c4ff92SAndroid Build Coastguard Worker     else
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
156*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
157*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(
158*89c4ff92SAndroid Build Coastguard Worker                          reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
159*89c4ff92SAndroid Build Coastguard Worker                          sizeof(T) * inputToForgetWeights.size())));
160*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
161*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
162*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoInputSize.size()),
163*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
164*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
165*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("inputToForgetWeights"),
166*89c4ff92SAndroid Build Coastguard Worker                                    weightQuantizationParameters));
167*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
170*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
171*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(
172*89c4ff92SAndroid Build Coastguard Worker                          reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
173*89c4ff92SAndroid Build Coastguard Worker                          sizeof(T) * inputToCellWeights.size())));
174*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
175*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
176*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoInputSize.size()),
177*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
178*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
179*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("inputToCellWeights"),
180*89c4ff92SAndroid Build Coastguard Worker                                    weightQuantizationParameters));
181*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
184*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
185*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(
186*89c4ff92SAndroid Build Coastguard Worker                          reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
187*89c4ff92SAndroid Build Coastguard Worker                          sizeof(T) * inputToOutputWeights.size())));
188*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
189*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
190*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoInputSize.size()),
191*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
192*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
193*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("inputToOutputWeights"),
194*89c4ff92SAndroid Build Coastguard Worker                                    weightQuantizationParameters));
195*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker     if (hasRecurrentToInputWeights)
198*89c4ff92SAndroid Build Coastguard Worker     {
199*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(CreateBuffer(
200*89c4ff92SAndroid Build Coastguard Worker             flatBufferBuilder,
201*89c4ff92SAndroid Build Coastguard Worker             flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
202*89c4ff92SAndroid Build Coastguard Worker                                            sizeof(T) * recurrentToInputWeights.size())));
203*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
204*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
205*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoOutputSize.size()),
206*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
207*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
208*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("recurrentToInputWeights"),
209*89c4ff92SAndroid Build Coastguard Worker                                        weightQuantizationParameters));
210*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
211*89c4ff92SAndroid Build Coastguard Worker     }
212*89c4ff92SAndroid Build Coastguard Worker     else
213*89c4ff92SAndroid Build Coastguard Worker     {
214*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
215*89c4ff92SAndroid Build Coastguard Worker     }
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
218*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
219*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
220*89c4ff92SAndroid Build Coastguard Worker                                                         recurrentToForgetWeights.data()),
221*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * recurrentToForgetWeights.size())));
222*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
223*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
224*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoOutputSize.size()),
225*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
226*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
227*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("recurrentToForgetWeights"),
228*89c4ff92SAndroid Build Coastguard Worker                                    weightQuantizationParameters));
229*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
232*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
233*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
234*89c4ff92SAndroid Build Coastguard Worker                                                         recurrentToCellWeights.data()),
235*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * recurrentToCellWeights.size())));
236*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
237*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
238*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoOutputSize.size()),
239*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
240*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
241*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("recurrentToCellWeights"),
242*89c4ff92SAndroid Build Coastguard Worker                                    weightQuantizationParameters));
243*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
246*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
247*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
248*89c4ff92SAndroid Build Coastguard Worker                                                         recurrentToOutputWeights.data()),
249*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(T) * recurrentToOutputWeights.size())));
250*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
251*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
252*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoOutputSize.size()),
253*89c4ff92SAndroid Build Coastguard Worker                                    tensorType,
254*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
255*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("recurrentToOutputWeights"),
256*89c4ff92SAndroid Build Coastguard Worker                                    weightQuantizationParameters));
257*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
258*89c4ff92SAndroid Build Coastguard Worker 
259*89c4ff92SAndroid Build Coastguard Worker     if (hasCellToInputWeights)
260*89c4ff92SAndroid Build Coastguard Worker     {
261*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
262*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
263*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
264*89c4ff92SAndroid Build Coastguard Worker                                                             cellToInputWeights.data()),
265*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellToInputWeights.size())));
266*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
267*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
268*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
269*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
270*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
271*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellToInputWeights"),
272*89c4ff92SAndroid Build Coastguard Worker                                        weightQuantizationParameters));
273*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
274*89c4ff92SAndroid Build Coastguard Worker     }
275*89c4ff92SAndroid Build Coastguard Worker     else
276*89c4ff92SAndroid Build Coastguard Worker     {
277*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
278*89c4ff92SAndroid Build Coastguard Worker     }
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker     if (hasCellToForgetWeights)
281*89c4ff92SAndroid Build Coastguard Worker     {
282*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
283*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
284*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
285*89c4ff92SAndroid Build Coastguard Worker                                                             cellToForgetWeights.data()),
286*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellToForgetWeights.size())));
287*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
288*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
289*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
290*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
291*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
292*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellToForgetWeights"),
293*89c4ff92SAndroid Build Coastguard Worker                                        weightQuantizationParameters));
294*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
295*89c4ff92SAndroid Build Coastguard Worker     }
296*89c4ff92SAndroid Build Coastguard Worker     else
297*89c4ff92SAndroid Build Coastguard Worker     {
298*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
299*89c4ff92SAndroid Build Coastguard Worker     }
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     if (hasCellToOutputWeights)
302*89c4ff92SAndroid Build Coastguard Worker     {
303*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
304*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
305*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
306*89c4ff92SAndroid Build Coastguard Worker                                                             cellToOutputWeights.data()),
307*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(T) * cellToOutputWeights.size())));
308*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
309*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
310*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
311*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
312*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
313*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellToOutputWeights"),
314*89c4ff92SAndroid Build Coastguard Worker                                        weightQuantizationParameters));
315*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
316*89c4ff92SAndroid Build Coastguard Worker     }
317*89c4ff92SAndroid Build Coastguard Worker     else
318*89c4ff92SAndroid Build Coastguard Worker     {
319*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
320*89c4ff92SAndroid Build Coastguard Worker     }
321*89c4ff92SAndroid Build Coastguard Worker 
322*89c4ff92SAndroid Build Coastguard Worker     if (hasInputGateBias)
323*89c4ff92SAndroid Build Coastguard Worker     {
324*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
325*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
326*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
327*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(float) * inputGateBias.size())));
328*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
329*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
330*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
331*89c4ff92SAndroid Build Coastguard Worker                                        ::tflite::TensorType_FLOAT32,
332*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
333*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("inputGateBias")));
334*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
335*89c4ff92SAndroid Build Coastguard Worker     }
336*89c4ff92SAndroid Build Coastguard Worker     else
337*89c4ff92SAndroid Build Coastguard Worker     {
338*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
339*89c4ff92SAndroid Build Coastguard Worker     }
340*89c4ff92SAndroid Build Coastguard Worker 
341*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
342*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
343*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
344*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(float) * forgetGateBias.size())));
345*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
346*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
347*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoNumUnits.size()),
348*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
349*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
350*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("forgetGateBias")));
351*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
352*89c4ff92SAndroid Build Coastguard Worker 
353*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
354*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
355*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
356*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(float) * cellBias.size())));
357*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
358*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
359*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoNumUnits.size()),
360*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
361*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
362*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("cellBias")));
363*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
364*89c4ff92SAndroid Build Coastguard Worker 
365*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(
366*89c4ff92SAndroid Build Coastguard Worker         CreateBuffer(flatBufferBuilder,
367*89c4ff92SAndroid Build Coastguard Worker                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
368*89c4ff92SAndroid Build Coastguard Worker                                                     sizeof(float) * outputGateBias.size())));
369*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
370*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
371*89c4ff92SAndroid Build Coastguard Worker                                                                            tensorInfoNumUnits.size()),
372*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
373*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
374*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("outputGateBias")));
375*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker     if (hasProjectionWeights)
378*89c4ff92SAndroid Build Coastguard Worker     {
379*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
380*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
381*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
382*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t*>(projectionWeights.data()),
383*89c4ff92SAndroid Build Coastguard Worker                              sizeof(T) * projectionWeights.size())));
384*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
385*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(projectionWeightDimensions.data(),
386*89c4ff92SAndroid Build Coastguard Worker                                                                                projectionWeightDimensions.size()),
387*89c4ff92SAndroid Build Coastguard Worker                                        tensorType,
388*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
389*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("projectionWeights"),
390*89c4ff92SAndroid Build Coastguard Worker                                        weightQuantizationParameters));
391*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
392*89c4ff92SAndroid Build Coastguard Worker     }
393*89c4ff92SAndroid Build Coastguard Worker     else
394*89c4ff92SAndroid Build Coastguard Worker     {
395*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
396*89c4ff92SAndroid Build Coastguard Worker     }
397*89c4ff92SAndroid Build Coastguard Worker 
398*89c4ff92SAndroid Build Coastguard Worker     if (hasProjectionBias)
399*89c4ff92SAndroid Build Coastguard Worker     {
400*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
401*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
402*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
403*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t*>(projectionBias.data()),
404*89c4ff92SAndroid Build Coastguard Worker                              sizeof(float) * projectionBias.size())));
405*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
406*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(projectionBiasDimensions.data(),
407*89c4ff92SAndroid Build Coastguard Worker                                                                                projectionBiasDimensions.size()),
408*89c4ff92SAndroid Build Coastguard Worker                                        ::tflite::TensorType_FLOAT32,
409*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
410*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("projectionBias")));
411*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
412*89c4ff92SAndroid Build Coastguard Worker     }
413*89c4ff92SAndroid Build Coastguard Worker     else
414*89c4ff92SAndroid Build Coastguard Worker     {
415*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
416*89c4ff92SAndroid Build Coastguard Worker     }
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
419*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
420*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
421*89c4ff92SAndroid Build Coastguard Worker                                                                            outputStateInDimensions.size()),
422*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
423*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
424*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("outputStateInInfo"),
425*89c4ff92SAndroid Build Coastguard Worker                                    quantizationParameters,
426*89c4ff92SAndroid Build Coastguard Worker                                    true));
427*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
430*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
431*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
432*89c4ff92SAndroid Build Coastguard Worker                                                                            cellStateInDimensions.size()),
433*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
434*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
435*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("cellStateInInfo"),
436*89c4ff92SAndroid Build Coastguard Worker                                    quantizationParameters,
437*89c4ff92SAndroid Build Coastguard Worker                                    true));
438*89c4ff92SAndroid Build Coastguard Worker     operatorInputs.push_back(tensors.size() - 1);
439*89c4ff92SAndroid Build Coastguard Worker 
440*89c4ff92SAndroid Build Coastguard Worker     if (hasInputLayerNormWeights)
441*89c4ff92SAndroid Build Coastguard Worker     {
442*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
443*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
444*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
445*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
446*89c4ff92SAndroid Build Coastguard Worker                              sizeof(float) * inputLayerNormWeights.size())));
447*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
448*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
449*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
450*89c4ff92SAndroid Build Coastguard Worker                                        ::tflite::TensorType_FLOAT32,
451*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
452*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("inputLayerNormWeights")));
453*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
454*89c4ff92SAndroid Build Coastguard Worker     }
455*89c4ff92SAndroid Build Coastguard Worker     else
456*89c4ff92SAndroid Build Coastguard Worker     {
457*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
458*89c4ff92SAndroid Build Coastguard Worker     }
459*89c4ff92SAndroid Build Coastguard Worker 
460*89c4ff92SAndroid Build Coastguard Worker     if (hasForgetLayerNormWeights)
461*89c4ff92SAndroid Build Coastguard Worker     {
462*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
463*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
464*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
465*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
466*89c4ff92SAndroid Build Coastguard Worker                              sizeof(float) * forgetLayerNormWeights.size())));
467*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
468*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
469*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
470*89c4ff92SAndroid Build Coastguard Worker                                        ::tflite::TensorType_FLOAT32,
471*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
472*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("forgetLayerNormWeights")));
473*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
474*89c4ff92SAndroid Build Coastguard Worker     }
475*89c4ff92SAndroid Build Coastguard Worker     else
476*89c4ff92SAndroid Build Coastguard Worker     {
477*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
478*89c4ff92SAndroid Build Coastguard Worker     }
479*89c4ff92SAndroid Build Coastguard Worker 
480*89c4ff92SAndroid Build Coastguard Worker     if (hasCellLayerNormWeights)
481*89c4ff92SAndroid Build Coastguard Worker     {
482*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
483*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
484*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
485*89c4ff92SAndroid Build Coastguard Worker                                                             cellLayerNormWeights.data()),
486*89c4ff92SAndroid Build Coastguard Worker                                                         sizeof(float) * cellLayerNormWeights.size())));
487*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
488*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
489*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
490*89c4ff92SAndroid Build Coastguard Worker                                        ::tflite::TensorType_FLOAT32,
491*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
492*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("cellLayerNormWeights")));
493*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
494*89c4ff92SAndroid Build Coastguard Worker     }
495*89c4ff92SAndroid Build Coastguard Worker     else
496*89c4ff92SAndroid Build Coastguard Worker     {
497*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
498*89c4ff92SAndroid Build Coastguard Worker     }
499*89c4ff92SAndroid Build Coastguard Worker 
500*89c4ff92SAndroid Build Coastguard Worker     if (hasOutputLayerNormWeights)
501*89c4ff92SAndroid Build Coastguard Worker     {
502*89c4ff92SAndroid Build Coastguard Worker         buffers.push_back(
503*89c4ff92SAndroid Build Coastguard Worker             CreateBuffer(flatBufferBuilder,
504*89c4ff92SAndroid Build Coastguard Worker                          flatBufferBuilder.CreateVector(
505*89c4ff92SAndroid Build Coastguard Worker                              reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
506*89c4ff92SAndroid Build Coastguard Worker                              sizeof(float) * outputLayerNormWeights.size())));
507*89c4ff92SAndroid Build Coastguard Worker         tensors.push_back(CreateTensor(flatBufferBuilder,
508*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
509*89c4ff92SAndroid Build Coastguard Worker                                                                                tensorInfoNumUnits.size()),
510*89c4ff92SAndroid Build Coastguard Worker                                        ::tflite::TensorType_FLOAT32,
511*89c4ff92SAndroid Build Coastguard Worker                                        buffers.size() - 1,
512*89c4ff92SAndroid Build Coastguard Worker                                        flatBufferBuilder.CreateString("outputLayerNormWeights")));
513*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(tensors.size() - 1);
514*89c4ff92SAndroid Build Coastguard Worker     }
515*89c4ff92SAndroid Build Coastguard Worker     else
516*89c4ff92SAndroid Build Coastguard Worker     {
517*89c4ff92SAndroid Build Coastguard Worker         operatorInputs.push_back(kTfLiteOptionalTensor);
518*89c4ff92SAndroid Build Coastguard Worker     }
519*89c4ff92SAndroid Build Coastguard Worker     buffers.push_back(CreateBuffer(flatBufferBuilder));
520*89c4ff92SAndroid Build Coastguard Worker     tensors.push_back(CreateTensor(flatBufferBuilder,
521*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
522*89c4ff92SAndroid Build Coastguard Worker                                                                            outputShape.size()),
523*89c4ff92SAndroid Build Coastguard Worker                                    ::tflite::TensorType_FLOAT32,
524*89c4ff92SAndroid Build Coastguard Worker                                    buffers.size() - 1,
525*89c4ff92SAndroid Build Coastguard Worker                                    flatBufferBuilder.CreateString("output")));
526*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> operatorOutputs;
527*89c4ff92SAndroid Build Coastguard Worker     operatorOutputs.push_back(tensors.size() - 1);
528*89c4ff92SAndroid Build Coastguard Worker 
529*89c4ff92SAndroid Build Coastguard Worker     // create operator
530*89c4ff92SAndroid Build Coastguard Worker     tflite::BuiltinOptions    operatorBuiltinOptionsType = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
531*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<void> operatorBuiltinOptions     =
532*89c4ff92SAndroid Build Coastguard Worker                                   CreateUnidirectionalSequenceLSTMOptions(flatBufferBuilder,
533*89c4ff92SAndroid Build Coastguard Worker                                                                           activationFunction,
534*89c4ff92SAndroid Build Coastguard Worker                                                                           clippingThresCell,
535*89c4ff92SAndroid Build Coastguard Worker                                                                           clippingThresProj,
536*89c4ff92SAndroid Build Coastguard Worker                                                                           isTimeMajor).Union();
537*89c4ff92SAndroid Build Coastguard Worker 
538*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<Operator> lstmOperator =
539*89c4ff92SAndroid Build Coastguard Worker                                       CreateOperator(flatBufferBuilder,
540*89c4ff92SAndroid Build Coastguard Worker                                                      0,
541*89c4ff92SAndroid Build Coastguard Worker                                                      flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
542*89c4ff92SAndroid Build Coastguard Worker                                                                                              operatorInputs.size()),
543*89c4ff92SAndroid Build Coastguard Worker                                                      flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
544*89c4ff92SAndroid Build Coastguard Worker                                                                                              operatorOutputs.size()),
545*89c4ff92SAndroid Build Coastguard Worker                                                      operatorBuiltinOptionsType, operatorBuiltinOptions);
546*89c4ff92SAndroid Build Coastguard Worker 
547*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<SubGraph> subgraph =
548*89c4ff92SAndroid Build Coastguard Worker                                       CreateSubGraph(flatBufferBuilder,
549*89c4ff92SAndroid Build Coastguard Worker                                                      flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
550*89c4ff92SAndroid Build Coastguard Worker                                                      flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
551*89c4ff92SAndroid Build Coastguard Worker                                                                                              operatorInputs.size()),
552*89c4ff92SAndroid Build Coastguard Worker                                                      flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
553*89c4ff92SAndroid Build Coastguard Worker                                                                                              operatorOutputs.size()),
554*89c4ff92SAndroid Build Coastguard Worker                                                      flatBufferBuilder.CreateVector(&lstmOperator, 1));
555*89c4ff92SAndroid Build Coastguard Worker 
556*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<flatbuffers::String> modelDescription =
557*89c4ff92SAndroid Build Coastguard Worker                                                  flatBufferBuilder.CreateString(
558*89c4ff92SAndroid Build Coastguard Worker                                                      "ArmnnDelegate: UnidirectionalSequenceLSTM Operator Model");
559*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<OperatorCode> operatorCode =
560*89c4ff92SAndroid Build Coastguard Worker                                                  CreateOperatorCode(flatBufferBuilder,
561*89c4ff92SAndroid Build Coastguard Worker                                                  tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM);
562*89c4ff92SAndroid Build Coastguard Worker 
563*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Offset<Model> flatbufferModel =
564*89c4ff92SAndroid Build Coastguard Worker                                    CreateModel(flatBufferBuilder,
565*89c4ff92SAndroid Build Coastguard Worker                                                TFLITE_SCHEMA_VERSION,
566*89c4ff92SAndroid Build Coastguard Worker                                                flatBufferBuilder.CreateVector(&operatorCode, 1),
567*89c4ff92SAndroid Build Coastguard Worker                                                flatBufferBuilder.CreateVector(&subgraph, 1),
568*89c4ff92SAndroid Build Coastguard Worker                                                modelDescription,
569*89c4ff92SAndroid Build Coastguard Worker                                                flatBufferBuilder.CreateVector(buffers));
570*89c4ff92SAndroid Build Coastguard Worker 
571*89c4ff92SAndroid Build Coastguard Worker     flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
572*89c4ff92SAndroid Build Coastguard Worker 
573*89c4ff92SAndroid Build Coastguard Worker     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
574*89c4ff92SAndroid Build Coastguard Worker                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
575*89c4ff92SAndroid Build Coastguard Worker }
576*89c4ff92SAndroid Build Coastguard Worker 
577*89c4ff92SAndroid Build Coastguard Worker template<typename T>
UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId> & backends,tflite::TensorType tensorType,int32_t batchSize,int32_t timeSize,int32_t inputSize,int32_t outputSize,int32_t numUnits,bool hasInputToInputWeights,const std::vector<T> & inputToInputWeights,const std::vector<T> & inputToForgetWeights,const std::vector<T> & inputToCellWeights,const std::vector<T> & inputToOutputWeights,bool hasRecurrentToInputWeights,const std::vector<T> & recurrentToInputWeights,const std::vector<T> & recurrentToForgetWeights,const std::vector<T> & recurrentToCellWeights,const std::vector<T> & recurrentToOutputWeights,bool hasCellToInputWeights,const std::vector<T> & cellToInputWeights,bool hasCellToForgetWeights,const std::vector<T> & cellToForgetWeights,bool hasCellToOutputWeights,const std::vector<T> & cellToOutputWeights,bool hasInputGateBias,const std::vector<float> & inputGateBias,const std::vector<float> & forgetGateBias,const std::vector<float> & cellBias,const std::vector<float> & outputGateBias,bool hasProjectionWeights,const std::vector<T> & projectionWeights,bool hasProjectionBias,const std::vector<float> & projectionBias,bool hasInputLayerNormWeights,const std::vector<float> & inputLayerNormWeights,bool hasForgetLayerNormWeights,const std::vector<float> & forgetLayerNormWeights,bool hasCellLayerNormWeights,const std::vector<float> & cellLayerNormWeights,bool hasOutputLayerNormWeights,const std::vector<float> & outputLayerNormWeights,std::vector<float> & inputValues,std::vector<float> & expectedOutputValues,tflite::ActivationFunctionType activationFunction,float clippingThresCell,float clippingThresProj,bool isTimeMajor,float quantScale=0.1f)578*89c4ff92SAndroid Build Coastguard Worker void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
579*89c4ff92SAndroid Build Coastguard Worker                                         tflite::TensorType tensorType,
580*89c4ff92SAndroid Build Coastguard Worker                                         int32_t batchSize,
581*89c4ff92SAndroid Build Coastguard Worker                                         int32_t timeSize,
582*89c4ff92SAndroid Build Coastguard Worker                                         int32_t inputSize,
583*89c4ff92SAndroid Build Coastguard Worker                                         int32_t outputSize,
584*89c4ff92SAndroid Build Coastguard Worker                                         int32_t numUnits,
585*89c4ff92SAndroid Build Coastguard Worker                                         bool hasInputToInputWeights,
586*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToInputWeights,
587*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToForgetWeights,
588*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToCellWeights,
589*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& inputToOutputWeights,
590*89c4ff92SAndroid Build Coastguard Worker                                         bool hasRecurrentToInputWeights,
591*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToInputWeights,
592*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToForgetWeights,
593*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToCellWeights,
594*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& recurrentToOutputWeights,
595*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellToInputWeights,
596*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellToInputWeights,
597*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellToForgetWeights,
598*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellToForgetWeights,
599*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellToOutputWeights,
600*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& cellToOutputWeights,
601*89c4ff92SAndroid Build Coastguard Worker                                         bool hasInputGateBias,
602*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& inputGateBias,
603*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& forgetGateBias,
604*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& cellBias,
605*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& outputGateBias,
606*89c4ff92SAndroid Build Coastguard Worker                                         bool hasProjectionWeights,
607*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<T>& projectionWeights,
608*89c4ff92SAndroid Build Coastguard Worker                                         bool hasProjectionBias,
609*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& projectionBias,
610*89c4ff92SAndroid Build Coastguard Worker                                         bool hasInputLayerNormWeights,
611*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& inputLayerNormWeights,
612*89c4ff92SAndroid Build Coastguard Worker                                         bool hasForgetLayerNormWeights,
613*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& forgetLayerNormWeights,
614*89c4ff92SAndroid Build Coastguard Worker                                         bool hasCellLayerNormWeights,
615*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& cellLayerNormWeights,
616*89c4ff92SAndroid Build Coastguard Worker                                         bool hasOutputLayerNormWeights,
617*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<float>& outputLayerNormWeights,
618*89c4ff92SAndroid Build Coastguard Worker                                         std::vector<float>& inputValues,
619*89c4ff92SAndroid Build Coastguard Worker                                         std::vector<float>& expectedOutputValues,
620*89c4ff92SAndroid Build Coastguard Worker                                         tflite::ActivationFunctionType activationFunction,
621*89c4ff92SAndroid Build Coastguard Worker                                         float clippingThresCell,
622*89c4ff92SAndroid Build Coastguard Worker                                         float clippingThresProj,
623*89c4ff92SAndroid Build Coastguard Worker                                         bool isTimeMajor,
624*89c4ff92SAndroid Build Coastguard Worker                                         float quantScale = 0.1f)
625*89c4ff92SAndroid Build Coastguard Worker {
626*89c4ff92SAndroid Build Coastguard Worker     using namespace delegateTestInterpreter;
627*89c4ff92SAndroid Build Coastguard Worker 
628*89c4ff92SAndroid Build Coastguard Worker     std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
629*89c4ff92SAndroid Build Coastguard Worker                                                                                 batchSize,
630*89c4ff92SAndroid Build Coastguard Worker                                                                                 timeSize,
631*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputSize,
632*89c4ff92SAndroid Build Coastguard Worker                                                                                 outputSize,
633*89c4ff92SAndroid Build Coastguard Worker                                                                                 numUnits,
634*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasInputToInputWeights,
635*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputToInputWeights,
636*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputToForgetWeights,
637*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputToCellWeights,
638*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputToOutputWeights,
639*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasRecurrentToInputWeights,
640*89c4ff92SAndroid Build Coastguard Worker                                                                                 recurrentToInputWeights,
641*89c4ff92SAndroid Build Coastguard Worker                                                                                 recurrentToForgetWeights,
642*89c4ff92SAndroid Build Coastguard Worker                                                                                 recurrentToCellWeights,
643*89c4ff92SAndroid Build Coastguard Worker                                                                                 recurrentToOutputWeights,
644*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasCellToInputWeights,
645*89c4ff92SAndroid Build Coastguard Worker                                                                                 cellToInputWeights,
646*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasCellToForgetWeights,
647*89c4ff92SAndroid Build Coastguard Worker                                                                                 cellToForgetWeights,
648*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasCellToOutputWeights,
649*89c4ff92SAndroid Build Coastguard Worker                                                                                 cellToOutputWeights,
650*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasInputGateBias,
651*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputGateBias,
652*89c4ff92SAndroid Build Coastguard Worker                                                                                 forgetGateBias,
653*89c4ff92SAndroid Build Coastguard Worker                                                                                 cellBias,
654*89c4ff92SAndroid Build Coastguard Worker                                                                                 outputGateBias,
655*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasProjectionWeights,
656*89c4ff92SAndroid Build Coastguard Worker                                                                                 projectionWeights,
657*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasProjectionBias,
658*89c4ff92SAndroid Build Coastguard Worker                                                                                 projectionBias,
659*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasInputLayerNormWeights,
660*89c4ff92SAndroid Build Coastguard Worker                                                                                 inputLayerNormWeights,
661*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasForgetLayerNormWeights,
662*89c4ff92SAndroid Build Coastguard Worker                                                                                 forgetLayerNormWeights,
663*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasCellLayerNormWeights,
664*89c4ff92SAndroid Build Coastguard Worker                                                                                 cellLayerNormWeights,
665*89c4ff92SAndroid Build Coastguard Worker                                                                                 hasOutputLayerNormWeights,
666*89c4ff92SAndroid Build Coastguard Worker                                                                                 outputLayerNormWeights,
667*89c4ff92SAndroid Build Coastguard Worker                                                                                 activationFunction,
668*89c4ff92SAndroid Build Coastguard Worker                                                                                 clippingThresCell,
669*89c4ff92SAndroid Build Coastguard Worker                                                                                 clippingThresProj,
670*89c4ff92SAndroid Build Coastguard Worker                                                                                 isTimeMajor,
671*89c4ff92SAndroid Build Coastguard Worker                                                                                 quantScale);
672*89c4ff92SAndroid Build Coastguard Worker 
673*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape;
674*89c4ff92SAndroid Build Coastguard Worker     if (isTimeMajor)
675*89c4ff92SAndroid Build Coastguard Worker     {
676*89c4ff92SAndroid Build Coastguard Worker         outputShape = {timeSize, batchSize, outputSize};
677*89c4ff92SAndroid Build Coastguard Worker     }
678*89c4ff92SAndroid Build Coastguard Worker     else
679*89c4ff92SAndroid Build Coastguard Worker     {
680*89c4ff92SAndroid Build Coastguard Worker         outputShape = {batchSize, timeSize, outputSize};
681*89c4ff92SAndroid Build Coastguard Worker     }
682*89c4ff92SAndroid Build Coastguard Worker 
683*89c4ff92SAndroid Build Coastguard Worker     // Setup interpreter with just TFLite Runtime.
684*89c4ff92SAndroid Build Coastguard Worker     auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
685*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
686*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
687*89c4ff92SAndroid Build Coastguard Worker     CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
688*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
689*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);
690*89c4ff92SAndroid Build Coastguard Worker 
691*89c4ff92SAndroid Build Coastguard Worker     // Setup interpreter with Arm NN Delegate applied.
692*89c4ff92SAndroid Build Coastguard Worker     auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
693*89c4ff92SAndroid Build Coastguard Worker     CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
694*89c4ff92SAndroid Build Coastguard Worker     CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
695*89c4ff92SAndroid Build Coastguard Worker     CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
696*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
697*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);
698*89c4ff92SAndroid Build Coastguard Worker 
699*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
700*89c4ff92SAndroid Build Coastguard Worker 
701*89c4ff92SAndroid Build Coastguard Worker     if (tensorType == ::tflite::TensorType_INT8)
702*89c4ff92SAndroid Build Coastguard Worker     {
703*89c4ff92SAndroid Build Coastguard Worker         // Allow 2% tolerance for Quantized weights
704*89c4ff92SAndroid Build Coastguard Worker         armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
705*89c4ff92SAndroid Build Coastguard Worker                                    expectedOutputValues.size(), 2);
706*89c4ff92SAndroid Build Coastguard Worker         armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
707*89c4ff92SAndroid Build Coastguard Worker                                    expectedOutputValues.size(), 2);
708*89c4ff92SAndroid Build Coastguard Worker         armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
709*89c4ff92SAndroid Build Coastguard Worker                                    expectedOutputValues.size(), 2);
710*89c4ff92SAndroid Build Coastguard Worker     }
711*89c4ff92SAndroid Build Coastguard Worker     else
712*89c4ff92SAndroid Build Coastguard Worker     {
713*89c4ff92SAndroid Build Coastguard Worker         armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
714*89c4ff92SAndroid Build Coastguard Worker     }
715*89c4ff92SAndroid Build Coastguard Worker 
716*89c4ff92SAndroid Build Coastguard Worker     tfLiteInterpreter.Cleanup();
717*89c4ff92SAndroid Build Coastguard Worker     armnnInterpreter.Cleanup();
718*89c4ff92SAndroid Build Coastguard Worker }
719*89c4ff92SAndroid Build Coastguard Worker 
720*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace