//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <schema_generated.h>

#include <doctest/doctest.h>

namespace
{

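// Builds a FlatBuffer model containing a single LSTM operator with the given
// weights, biases and options, and returns the serialised model as raw bytes.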
template <typename T>
std::vector<char> CreateLstmTfLiteModel(tflite::TensorType tensorType,
                                        int32_t batchSize,
                                        int32_t inputSize,
                                        int32_t outputSize,
                                        int32_t numUnits,
                                        bool hasInputToInputWeights,
                                        const std::vector<T>& inputToInputWeights,
                                        const std::vector<T>& inputToForgetWeights,
                                        const std::vector<T>& inputToCellWeights,
                                        const std::vector<T>& inputToOutputWeights,
                                        bool hasRecurrentToInputWeights,
                                        const std::vector<T>& recurrentToInputWeights,
                                        const std::vector<T>& recurrentToForgetWeights,
                                        const std::vector<T>& recurrentToCellWeights,
                                        const std::vector<T>& recurrentToOutputWeights,
                                        bool hasCellToInputWeights,
                                        const std::vector<T>& cellToInputWeights,
                                        bool hasCellToForgetWeights,
                                        const std::vector<T>& cellToForgetWeights,
                                        bool hasCellToOutputWeights,
                                        const std::vector<T>& cellToOutputWeights,
                                        bool hasInputGateBias,
                                        const std::vector<T>& inputGateBias,
                                        const std::vector<T>& forgetGateBias,
                                        const std::vector<T>& cellBias,
                                        const std::vector<T>& outputGateBias,
                                        bool hasProjectionWeights,
                                        const std::vector<T>& projectionWeights,
                                        bool hasProjectionBias,
                                        const std::vector<T>& projectionBias,
                                        bool hasInputLayerNormWeights,
                                        const std::vector<T>& inputLayerNormWeights,
                                        bool hasForgetLayerNormWeights,
                                        const std::vector<T>& forgetLayerNormWeights,
                                        bool hasCellLayerNormWeights,
                                        const std::vector<T>& cellLayerNormWeights,
                                        bool hasOutputLayerNormWeights,
                                        const std::vector<T>& outputLayerNormWeights,
                                        tflite::ActivationFunctionType activationFunction,
                                        float clippingThresCell,
                                        float clippingThresProj,
                                        float quantScale = 1.0f,
                                        int quantOffset  = 0,
                                        float outputQuantScale = 2.0f,
                                        int outputQuantOffset  = 0)
{
    std::vector<int32_t> tensorInfo0 {};
    // Shape helpers: per-gate vectors are {numUnits}; the input and recurrent
    // weight shapes below hardcode inputSize == 2 and outputSize == 4, so
    // callers must pass matching dimensions.
    std::vector<int32_t> tensorInfo4 {numUnits};
    std::vector<int32_t> tensorInfo8 {numUnits, static_cast<int32_t>(2)};
    std::vector<int32_t> tensorInfo16 {numUnits, static_cast<int32_t>(4)};

    std::vector<int32_t> inputShape {batchSize, inputSize};
    std::vector<int32_t> outputShape {batchSize, outputSize};

    std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
    std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};

    std::vector<int> operatorInputs;
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    std::vector<flatbuffers::Offset<Tensor>> tensors;

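    // Buffers and tensors are pushed in lockstep below, so buffers.size() - 1
    // also serves as the index of the most recently created tensor when
    // recording operatorInputs and operatorOutputs.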
    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    auto outputQuantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));

    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
                                                                           inputShape.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("input_0"),
                                   quantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

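    // Optional LSTM inputs (CIFG, peephole, projection and layer-norm tensors)
    // are marked with kTfLiteOptionalTensor when absent, so the operator's
    // input list keeps its fixed ordering.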
    if (hasInputToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
                                                        sizeof(T) * inputToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                               tensorInfo8.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
                                                    sizeof(T) * inputToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToForgetWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
                                                    sizeof(T) * inputToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToCellWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
                                                    sizeof(T) * inputToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToOutputWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasRecurrentToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
                                                        sizeof(T) * recurrentToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                               tensorInfo16.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("recurrentToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToForgetWeights.data()),
                                                    sizeof(T) * recurrentToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToForgetWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToCellWeights.data()),
                                                    sizeof(T) * recurrentToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToCellWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToOutputWeights.data()),
                                                    sizeof(T) * recurrentToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToOutputWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasCellToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
                                                        sizeof(T) * cellToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToForgetWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
                                                        sizeof(T) * cellToForgetWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToForgetWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToOutputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
                                                        sizeof(T) * cellToOutputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToOutputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasInputGateBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
                                                        sizeof(T) * inputGateBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputGateBias"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
                                                    sizeof(T) * forgetGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("forgetGateBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
                                                    sizeof(T) * cellBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
                                                    sizeof(T) * outputGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputGateBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasProjectionWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionWeights.data()),
                                                        sizeof(T) * projectionWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasProjectionBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionBias.data()),
                                                        sizeof(T) * projectionBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionBias"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

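    // outputStateIn and cellStateIn are created as variable tensors (the final
    // 'true' argument to CreateTensor), matching TfLite's stateful LSTM inputs.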
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
                                                                           outputStateInDimensions.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputStateInInfo"),
                                   outputQuantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
                                                                           cellStateInDimensions.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellStateInInfo"),
                                   outputQuantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

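    // Layer normalisation weights are only present for layer-norm LSTM variants.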
    if (hasInputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
                                                        sizeof(T) * inputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasForgetLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
                                                        sizeof(T) * forgetLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("forgetLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellLayerNormWeights.data()),
                                                        sizeof(T) * cellLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasOutputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
                                                        sizeof(T) * outputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("outputLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }
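
    // Model output tensor.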
    int outputBufferId = buffers.size();
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
                                                                           outputShape.size()),
                                   tensorType,
                                   outputBufferId,
                                   flatBufferBuilder.CreateString("output"),
                                   outputQuantizationParameters));
    std::vector<int> operatorOutputs;
    operatorOutputs.push_back(buffers.size() - 1);

    // Create the LSTM operator. LSTMOptions carries the gate activation
    // function and the cell/projection clipping thresholds.
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_LSTMOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions =
        CreateLSTMOptions(flatBufferBuilder,
                          activationFunction,
                          clippingThresCell,
                          clippingThresProj).Union();

    flatbuffers::Offset<Operator> lstmOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType, operatorBuiltinOptions);

    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       flatBufferBuilder.CreateVector(&lstmOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: LSTM Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
                                                                        tflite::BuiltinOperator_LSTM);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

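// Runs the generated LSTM model on the reference TfLite runtime and again with
// the Arm NN delegate applied, then checks that both runs agree with
// expectedOutputValues and the expected output shape.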
template <typename T>
void LstmTestImpl(std::vector<armnn::BackendId>& backends,
                  tflite::TensorType tensorType,
                  int32_t batchSize,
                  int32_t inputSize,
                  int32_t outputSize,
                  int32_t numUnits,
                  bool hasInputToInputWeights,
                  const std::vector<T>& inputToInputWeights,
                  const std::vector<T>& inputToForgetWeights,
                  const std::vector<T>& inputToCellWeights,
                  const std::vector<T>& inputToOutputWeights,
                  bool hasRecurrentToInputWeights,
                  const std::vector<T>& recurrentToInputWeights,
                  const std::vector<T>& recurrentToForgetWeights,
                  const std::vector<T>& recurrentToCellWeights,
                  const std::vector<T>& recurrentToOutputWeights,
                  bool hasCellToInputWeights,
                  const std::vector<T>& cellToInputWeights,
                  bool hasCellToForgetWeights,
                  const std::vector<T>& cellToForgetWeights,
                  bool hasCellToOutputWeights,
                  const std::vector<T>& cellToOutputWeights,
                  bool hasInputGateBias,
                  const std::vector<T>& inputGateBias,
                  const std::vector<T>& forgetGateBias,
                  const std::vector<T>& cellBias,
                  const std::vector<T>& outputGateBias,
                  bool hasProjectionWeights,
                  const std::vector<T>& projectionWeights,
                  bool hasProjectionBias,
                  const std::vector<T>& projectionBias,
                  bool hasInputLayerNormWeights,
                  const std::vector<T>& inputLayerNormWeights,
                  bool hasForgetLayerNormWeights,
                  const std::vector<T>& forgetLayerNormWeights,
                  bool hasCellLayerNormWeights,
                  const std::vector<T>& cellLayerNormWeights,
                  bool hasOutputLayerNormWeights,
                  const std::vector<T>& outputLayerNormWeights,
                  std::vector<T>& inputValues,
                  std::vector<T>& expectedOutputValues,
                  tflite::ActivationFunctionType activationFunction,
                  float clippingThresCell,
                  float clippingThresProj)
{
    using namespace delegateTestInterpreter;

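    // Serialise a single-operator LSTM model with the requested configuration.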
    std::vector<char> modelBuffer = CreateLstmTfLiteModel(tensorType,
                                                          batchSize,
                                                          inputSize,
                                                          outputSize,
                                                          numUnits,
                                                          hasInputToInputWeights,
                                                          inputToInputWeights,
                                                          inputToForgetWeights,
                                                          inputToCellWeights,
                                                          inputToOutputWeights,
                                                          hasRecurrentToInputWeights,
                                                          recurrentToInputWeights,
                                                          recurrentToForgetWeights,
                                                          recurrentToCellWeights,
                                                          recurrentToOutputWeights,
                                                          hasCellToInputWeights,
                                                          cellToInputWeights,
                                                          hasCellToForgetWeights,
                                                          cellToForgetWeights,
                                                          hasCellToOutputWeights,
                                                          cellToOutputWeights,
                                                          hasInputGateBias,
                                                          inputGateBias,
                                                          forgetGateBias,
                                                          cellBias,
                                                          outputGateBias,
                                                          hasProjectionWeights,
                                                          projectionWeights,
                                                          hasProjectionBias,
                                                          projectionBias,
                                                          hasInputLayerNormWeights,
                                                          inputLayerNormWeights,
                                                          hasForgetLayerNormWeights,
                                                          forgetLayerNormWeights,
                                                          hasCellLayerNormWeights,
                                                          cellLayerNormWeights,
                                                          hasOutputLayerNormWeights,
                                                          outputLayerNormWeights,
                                                          activationFunction,
                                                          clippingThresCell,
                                                          clippingThresProj);

    std::vector<int32_t> expectedOutputShape {batchSize, outputSize};

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<T>       tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<T>       armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);

    // Compare both runs against each other and against the expected values/shape.
    armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
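
// A minimal usage sketch (illustrative only, not a real test case): build
// weight/bias vectors whose sizes match the shapes hardcoded in
// CreateLstmTfLiteModel (inputSize == 2 and outputSize == 4 for the weight
// tensors), compute expectedOutputValues with a trusted reference
// implementation, and call e.g.
// LstmTestImpl<float>(backends, tflite::TensorType_FLOAT32, ...).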

} // anonymous namespace