xref: /aosp_15_r20/external/armnn/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "TestUtils.hpp"
9 
10 #include <armnn_delegate.hpp>
11 #include <DelegateTestInterpreter.hpp>
12 
13 #include <flatbuffers/flatbuffers.h>
14 #include <tensorflow/lite/kernels/register.h>
15 #include <tensorflow/lite/version.h>
16 
17 #include <schema_generated.h>
18 
19 #include <doctest/doctest.h>
20 
21 #include <armnn/utility/IgnoreUnused.hpp>
22 #include <armnn/utility/NumericCast.hpp>
23 #include <armnn/TypesUtils.hpp>
24 
25 #include <armnn/Types.hpp>
26 
27 #include <initializer_list>
28 #include <iterator>
29 #include <vector>
30 
31 namespace
32 {
33 
34 template<typename T>
CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,int32_t batchSize,int32_t timeSize,int32_t inputSize,int32_t outputSize,int32_t numUnits,bool hasInputToInputWeights,const std::vector<T> & inputToInputWeights,const std::vector<T> & inputToForgetWeights,const std::vector<T> & inputToCellWeights,const std::vector<T> & inputToOutputWeights,bool hasRecurrentToInputWeights,const std::vector<T> & recurrentToInputWeights,const std::vector<T> & recurrentToForgetWeights,const std::vector<T> & recurrentToCellWeights,const std::vector<T> & recurrentToOutputWeights,bool hasCellToInputWeights,const std::vector<T> & cellToInputWeights,bool hasCellToForgetWeights,const std::vector<T> & cellToForgetWeights,bool hasCellToOutputWeights,const std::vector<T> & cellToOutputWeights,bool hasInputGateBias,const std::vector<float> & inputGateBias,const std::vector<float> & forgetGateBias,const std::vector<float> & cellBias,const std::vector<float> & outputGateBias,bool hasProjectionWeights,const std::vector<T> & projectionWeights,bool hasProjectionBias,const std::vector<float> & projectionBias,bool hasInputLayerNormWeights,const std::vector<float> & inputLayerNormWeights,bool hasForgetLayerNormWeights,const std::vector<float> & forgetLayerNormWeights,bool hasCellLayerNormWeights,const std::vector<float> & cellLayerNormWeights,bool hasOutputLayerNormWeights,const std::vector<float> & outputLayerNormWeights,tflite::ActivationFunctionType activationFunction,float clippingThresCell,float clippingThresProj,bool isTimeMajor,float quantScale,int quantOffset=0)35 std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,
36                                                               int32_t batchSize,
37                                                               int32_t timeSize,
38                                                               int32_t inputSize,
39                                                               int32_t outputSize,
40                                                               int32_t numUnits,
41                                                               bool hasInputToInputWeights,
42                                                               const std::vector<T>& inputToInputWeights,
43                                                               const std::vector<T>& inputToForgetWeights,
44                                                               const std::vector<T>& inputToCellWeights,
45                                                               const std::vector<T>& inputToOutputWeights,
46                                                               bool hasRecurrentToInputWeights,
47                                                               const std::vector<T>& recurrentToInputWeights,
48                                                               const std::vector<T>& recurrentToForgetWeights,
49                                                               const std::vector<T>& recurrentToCellWeights,
50                                                               const std::vector<T>& recurrentToOutputWeights,
51                                                               bool hasCellToInputWeights,
52                                                               const std::vector<T>& cellToInputWeights,
53                                                               bool hasCellToForgetWeights,
54                                                               const std::vector<T>& cellToForgetWeights,
55                                                               bool hasCellToOutputWeights,
56                                                               const std::vector<T>& cellToOutputWeights,
57                                                               bool hasInputGateBias,
58                                                               const std::vector<float>& inputGateBias,
59                                                               const std::vector<float>& forgetGateBias,
60                                                               const std::vector<float>& cellBias,
61                                                               const std::vector<float>& outputGateBias,
62                                                               bool hasProjectionWeights,
63                                                               const std::vector<T>& projectionWeights,
64                                                               bool hasProjectionBias,
65                                                               const std::vector<float>& projectionBias,
66                                                               bool hasInputLayerNormWeights,
67                                                               const std::vector<float>& inputLayerNormWeights,
68                                                               bool hasForgetLayerNormWeights,
69                                                               const std::vector<float>& forgetLayerNormWeights,
70                                                               bool hasCellLayerNormWeights,
71                                                               const std::vector<float>& cellLayerNormWeights,
72                                                               bool hasOutputLayerNormWeights,
73                                                               const std::vector<float>& outputLayerNormWeights,
74                                                               tflite::ActivationFunctionType activationFunction,
75                                                               float clippingThresCell,
76                                                               float clippingThresProj,
77                                                               bool isTimeMajor,
78                                                               float quantScale,
79                                                               int quantOffset = 0)
80 {
81 
82     std::vector<int32_t> tensorInfo0{};
83     std::vector<int32_t> tensorInfoNumUnits{numUnits};
84     std::vector<int32_t> tensorInfoInputSize{numUnits, inputSize};
85     std::vector<int32_t> tensorInfoOutputSize{numUnits, outputSize};
86 
87     std::vector<int32_t> inputShape;
88     std::vector<int32_t> outputShape;
89     if (isTimeMajor)
90     {
91         inputShape  = {timeSize, batchSize, inputSize};
92         outputShape = {timeSize, batchSize, outputSize};
93     }
94     else
95     {
96         inputShape  = {batchSize, timeSize, inputSize};
97         outputShape = {batchSize, timeSize, outputSize};
98     }
99     std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
100     std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
101     std::vector<int32_t> projectionWeightDimensions{outputSize, numUnits};
102     std::vector<int32_t> projectionBiasDimensions{outputSize};
103 
104     std::vector<int> operatorInputs;
105     using namespace tflite;
106     flatbuffers::FlatBufferBuilder                   flatBufferBuilder;
107     std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
108     std::vector<flatbuffers::Offset<Tensor>>         tensors;
109 
110     auto quantizationParameters =
111              CreateQuantizationParameters(flatBufferBuilder,
112                                           0,
113                                           0,
114                                           flatBufferBuilder.CreateVector<float>({1.0f}),
115                                           flatBufferBuilder.CreateVector<int64_t>({0}));
116 
117     auto weightQuantizationParameters =
118              CreateQuantizationParameters(flatBufferBuilder,
119                                           0,
120                                           0,
121                                           flatBufferBuilder.CreateVector<float>({quantScale}),
122                                           flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
123 
124     buffers.push_back(CreateBuffer(flatBufferBuilder));
125     buffers.push_back(CreateBuffer(flatBufferBuilder));
126     tensors.push_back(CreateTensor(flatBufferBuilder,
127                                    flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
128                                                                            inputShape.size()),
129                                    ::tflite::TensorType_FLOAT32,
130                                    buffers.size() - 1,
131                                    flatBufferBuilder.CreateString("input_0")));
132     operatorInputs.push_back(tensors.size() - 1);
133 
134     if (hasInputToInputWeights)
135     {
136         buffers.push_back(
137             CreateBuffer(flatBufferBuilder,
138                          flatBufferBuilder.CreateVector(
139                              reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
140                              sizeof(T) * inputToInputWeights.size())));
141         tensors.push_back(CreateTensor(flatBufferBuilder,
142                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
143                                                                                tensorInfoInputSize.size()),
144                                        tensorType,
145                                        buffers.size() - 1,
146                                        flatBufferBuilder.CreateString("inputToInputWeights"),
147                                        weightQuantizationParameters));
148         operatorInputs.push_back(tensors.size() - 1);
149     }
150     else
151     {
152         operatorInputs.push_back(kTfLiteOptionalTensor);
153     }
154 
155     buffers.push_back(
156         CreateBuffer(flatBufferBuilder,
157                      flatBufferBuilder.CreateVector(
158                          reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
159                          sizeof(T) * inputToForgetWeights.size())));
160     tensors.push_back(CreateTensor(flatBufferBuilder,
161                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
162                                                                            tensorInfoInputSize.size()),
163                                    tensorType,
164                                    buffers.size() - 1,
165                                    flatBufferBuilder.CreateString("inputToForgetWeights"),
166                                    weightQuantizationParameters));
167     operatorInputs.push_back(tensors.size() - 1);
168 
169     buffers.push_back(
170         CreateBuffer(flatBufferBuilder,
171                      flatBufferBuilder.CreateVector(
172                          reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
173                          sizeof(T) * inputToCellWeights.size())));
174     tensors.push_back(CreateTensor(flatBufferBuilder,
175                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
176                                                                            tensorInfoInputSize.size()),
177                                    tensorType,
178                                    buffers.size() - 1,
179                                    flatBufferBuilder.CreateString("inputToCellWeights"),
180                                    weightQuantizationParameters));
181     operatorInputs.push_back(tensors.size() - 1);
182 
183     buffers.push_back(
184         CreateBuffer(flatBufferBuilder,
185                      flatBufferBuilder.CreateVector(
186                          reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
187                          sizeof(T) * inputToOutputWeights.size())));
188     tensors.push_back(CreateTensor(flatBufferBuilder,
189                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
190                                                                            tensorInfoInputSize.size()),
191                                    tensorType,
192                                    buffers.size() - 1,
193                                    flatBufferBuilder.CreateString("inputToOutputWeights"),
194                                    weightQuantizationParameters));
195     operatorInputs.push_back(tensors.size() - 1);
196 
197     if (hasRecurrentToInputWeights)
198     {
199         buffers.push_back(CreateBuffer(
200             flatBufferBuilder,
201             flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
202                                            sizeof(T) * recurrentToInputWeights.size())));
203         tensors.push_back(CreateTensor(flatBufferBuilder,
204                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
205                                                                                tensorInfoOutputSize.size()),
206                                        tensorType,
207                                        buffers.size() - 1,
208                                        flatBufferBuilder.CreateString("recurrentToInputWeights"),
209                                        weightQuantizationParameters));
210         operatorInputs.push_back(tensors.size() - 1);
211     }
212     else
213     {
214         operatorInputs.push_back(kTfLiteOptionalTensor);
215     }
216 
217     buffers.push_back(
218         CreateBuffer(flatBufferBuilder,
219                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
220                                                         recurrentToForgetWeights.data()),
221                                                     sizeof(T) * recurrentToForgetWeights.size())));
222     tensors.push_back(CreateTensor(flatBufferBuilder,
223                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
224                                                                            tensorInfoOutputSize.size()),
225                                    tensorType,
226                                    buffers.size() - 1,
227                                    flatBufferBuilder.CreateString("recurrentToForgetWeights"),
228                                    weightQuantizationParameters));
229     operatorInputs.push_back(tensors.size() - 1);
230 
231     buffers.push_back(
232         CreateBuffer(flatBufferBuilder,
233                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
234                                                         recurrentToCellWeights.data()),
235                                                     sizeof(T) * recurrentToCellWeights.size())));
236     tensors.push_back(CreateTensor(flatBufferBuilder,
237                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
238                                                                            tensorInfoOutputSize.size()),
239                                    tensorType,
240                                    buffers.size() - 1,
241                                    flatBufferBuilder.CreateString("recurrentToCellWeights"),
242                                    weightQuantizationParameters));
243     operatorInputs.push_back(tensors.size() - 1);
244 
245     buffers.push_back(
246         CreateBuffer(flatBufferBuilder,
247                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
248                                                         recurrentToOutputWeights.data()),
249                                                     sizeof(T) * recurrentToOutputWeights.size())));
250     tensors.push_back(CreateTensor(flatBufferBuilder,
251                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
252                                                                            tensorInfoOutputSize.size()),
253                                    tensorType,
254                                    buffers.size() - 1,
255                                    flatBufferBuilder.CreateString("recurrentToOutputWeights"),
256                                    weightQuantizationParameters));
257     operatorInputs.push_back(tensors.size() - 1);
258 
259     if (hasCellToInputWeights)
260     {
261         buffers.push_back(
262             CreateBuffer(flatBufferBuilder,
263                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
264                                                             cellToInputWeights.data()),
265                                                         sizeof(T) * cellToInputWeights.size())));
266         tensors.push_back(CreateTensor(flatBufferBuilder,
267                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
268                                                                                tensorInfoNumUnits.size()),
269                                        tensorType,
270                                        buffers.size() - 1,
271                                        flatBufferBuilder.CreateString("cellToInputWeights"),
272                                        weightQuantizationParameters));
273         operatorInputs.push_back(tensors.size() - 1);
274     }
275     else
276     {
277         operatorInputs.push_back(kTfLiteOptionalTensor);
278     }
279 
280     if (hasCellToForgetWeights)
281     {
282         buffers.push_back(
283             CreateBuffer(flatBufferBuilder,
284                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
285                                                             cellToForgetWeights.data()),
286                                                         sizeof(T) * cellToForgetWeights.size())));
287         tensors.push_back(CreateTensor(flatBufferBuilder,
288                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
289                                                                                tensorInfoNumUnits.size()),
290                                        tensorType,
291                                        buffers.size() - 1,
292                                        flatBufferBuilder.CreateString("cellToForgetWeights"),
293                                        weightQuantizationParameters));
294         operatorInputs.push_back(tensors.size() - 1);
295     }
296     else
297     {
298         operatorInputs.push_back(kTfLiteOptionalTensor);
299     }
300 
301     if (hasCellToOutputWeights)
302     {
303         buffers.push_back(
304             CreateBuffer(flatBufferBuilder,
305                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
306                                                             cellToOutputWeights.data()),
307                                                         sizeof(T) * cellToOutputWeights.size())));
308         tensors.push_back(CreateTensor(flatBufferBuilder,
309                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
310                                                                                tensorInfoNumUnits.size()),
311                                        tensorType,
312                                        buffers.size() - 1,
313                                        flatBufferBuilder.CreateString("cellToOutputWeights"),
314                                        weightQuantizationParameters));
315         operatorInputs.push_back(tensors.size() - 1);
316     }
317     else
318     {
319         operatorInputs.push_back(kTfLiteOptionalTensor);
320     }
321 
322     if (hasInputGateBias)
323     {
324         buffers.push_back(
325             CreateBuffer(flatBufferBuilder,
326                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
327                                                         sizeof(float) * inputGateBias.size())));
328         tensors.push_back(CreateTensor(flatBufferBuilder,
329                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
330                                                                                tensorInfoNumUnits.size()),
331                                        ::tflite::TensorType_FLOAT32,
332                                        buffers.size() - 1,
333                                        flatBufferBuilder.CreateString("inputGateBias")));
334         operatorInputs.push_back(tensors.size() - 1);
335     }
336     else
337     {
338         operatorInputs.push_back(kTfLiteOptionalTensor);
339     }
340 
341     buffers.push_back(
342         CreateBuffer(flatBufferBuilder,
343                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
344                                                     sizeof(float) * forgetGateBias.size())));
345     tensors.push_back(CreateTensor(flatBufferBuilder,
346                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
347                                                                            tensorInfoNumUnits.size()),
348                                    ::tflite::TensorType_FLOAT32,
349                                    buffers.size() - 1,
350                                    flatBufferBuilder.CreateString("forgetGateBias")));
351     operatorInputs.push_back(tensors.size() - 1);
352 
353     buffers.push_back(
354         CreateBuffer(flatBufferBuilder,
355                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
356                                                     sizeof(float) * cellBias.size())));
357     tensors.push_back(CreateTensor(flatBufferBuilder,
358                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
359                                                                            tensorInfoNumUnits.size()),
360                                    ::tflite::TensorType_FLOAT32,
361                                    buffers.size() - 1,
362                                    flatBufferBuilder.CreateString("cellBias")));
363     operatorInputs.push_back(tensors.size() - 1);
364 
365     buffers.push_back(
366         CreateBuffer(flatBufferBuilder,
367                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
368                                                     sizeof(float) * outputGateBias.size())));
369     tensors.push_back(CreateTensor(flatBufferBuilder,
370                                    flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
371                                                                            tensorInfoNumUnits.size()),
372                                    ::tflite::TensorType_FLOAT32,
373                                    buffers.size() - 1,
374                                    flatBufferBuilder.CreateString("outputGateBias")));
375     operatorInputs.push_back(tensors.size() - 1);
376 
377     if (hasProjectionWeights)
378     {
379         buffers.push_back(
380             CreateBuffer(flatBufferBuilder,
381                          flatBufferBuilder.CreateVector(
382                              reinterpret_cast<const uint8_t*>(projectionWeights.data()),
383                              sizeof(T) * projectionWeights.size())));
384         tensors.push_back(CreateTensor(flatBufferBuilder,
385                                        flatBufferBuilder.CreateVector<int32_t>(projectionWeightDimensions.data(),
386                                                                                projectionWeightDimensions.size()),
387                                        tensorType,
388                                        buffers.size() - 1,
389                                        flatBufferBuilder.CreateString("projectionWeights"),
390                                        weightQuantizationParameters));
391         operatorInputs.push_back(tensors.size() - 1);
392     }
393     else
394     {
395         operatorInputs.push_back(kTfLiteOptionalTensor);
396     }
397 
398     if (hasProjectionBias)
399     {
400         buffers.push_back(
401             CreateBuffer(flatBufferBuilder,
402                          flatBufferBuilder.CreateVector(
403                              reinterpret_cast<const uint8_t*>(projectionBias.data()),
404                              sizeof(float) * projectionBias.size())));
405         tensors.push_back(CreateTensor(flatBufferBuilder,
406                                        flatBufferBuilder.CreateVector<int32_t>(projectionBiasDimensions.data(),
407                                                                                projectionBiasDimensions.size()),
408                                        ::tflite::TensorType_FLOAT32,
409                                        buffers.size() - 1,
410                                        flatBufferBuilder.CreateString("projectionBias")));
411         operatorInputs.push_back(tensors.size() - 1);
412     }
413     else
414     {
415         operatorInputs.push_back(kTfLiteOptionalTensor);
416     }
417 
418     buffers.push_back(CreateBuffer(flatBufferBuilder));
419     tensors.push_back(CreateTensor(flatBufferBuilder,
420                                    flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
421                                                                            outputStateInDimensions.size()),
422                                    ::tflite::TensorType_FLOAT32,
423                                    buffers.size() - 1,
424                                    flatBufferBuilder.CreateString("outputStateInInfo"),
425                                    quantizationParameters,
426                                    true));
427     operatorInputs.push_back(tensors.size() - 1);
428 
429     buffers.push_back(CreateBuffer(flatBufferBuilder));
430     tensors.push_back(CreateTensor(flatBufferBuilder,
431                                    flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
432                                                                            cellStateInDimensions.size()),
433                                    ::tflite::TensorType_FLOAT32,
434                                    buffers.size() - 1,
435                                    flatBufferBuilder.CreateString("cellStateInInfo"),
436                                    quantizationParameters,
437                                    true));
438     operatorInputs.push_back(tensors.size() - 1);
439 
440     if (hasInputLayerNormWeights)
441     {
442         buffers.push_back(
443             CreateBuffer(flatBufferBuilder,
444                          flatBufferBuilder.CreateVector(
445                              reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
446                              sizeof(float) * inputLayerNormWeights.size())));
447         tensors.push_back(CreateTensor(flatBufferBuilder,
448                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
449                                                                                tensorInfoNumUnits.size()),
450                                        ::tflite::TensorType_FLOAT32,
451                                        buffers.size() - 1,
452                                        flatBufferBuilder.CreateString("inputLayerNormWeights")));
453         operatorInputs.push_back(tensors.size() - 1);
454     }
455     else
456     {
457         operatorInputs.push_back(kTfLiteOptionalTensor);
458     }
459 
460     if (hasForgetLayerNormWeights)
461     {
462         buffers.push_back(
463             CreateBuffer(flatBufferBuilder,
464                          flatBufferBuilder.CreateVector(
465                              reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
466                              sizeof(float) * forgetLayerNormWeights.size())));
467         tensors.push_back(CreateTensor(flatBufferBuilder,
468                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
469                                                                                tensorInfoNumUnits.size()),
470                                        ::tflite::TensorType_FLOAT32,
471                                        buffers.size() - 1,
472                                        flatBufferBuilder.CreateString("forgetLayerNormWeights")));
473         operatorInputs.push_back(tensors.size() - 1);
474     }
475     else
476     {
477         operatorInputs.push_back(kTfLiteOptionalTensor);
478     }
479 
480     if (hasCellLayerNormWeights)
481     {
482         buffers.push_back(
483             CreateBuffer(flatBufferBuilder,
484                          flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
485                                                             cellLayerNormWeights.data()),
486                                                         sizeof(float) * cellLayerNormWeights.size())));
487         tensors.push_back(CreateTensor(flatBufferBuilder,
488                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
489                                                                                tensorInfoNumUnits.size()),
490                                        ::tflite::TensorType_FLOAT32,
491                                        buffers.size() - 1,
492                                        flatBufferBuilder.CreateString("cellLayerNormWeights")));
493         operatorInputs.push_back(tensors.size() - 1);
494     }
495     else
496     {
497         operatorInputs.push_back(kTfLiteOptionalTensor);
498     }
499 
500     if (hasOutputLayerNormWeights)
501     {
502         buffers.push_back(
503             CreateBuffer(flatBufferBuilder,
504                          flatBufferBuilder.CreateVector(
505                              reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
506                              sizeof(float) * outputLayerNormWeights.size())));
507         tensors.push_back(CreateTensor(flatBufferBuilder,
508                                        flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
509                                                                                tensorInfoNumUnits.size()),
510                                        ::tflite::TensorType_FLOAT32,
511                                        buffers.size() - 1,
512                                        flatBufferBuilder.CreateString("outputLayerNormWeights")));
513         operatorInputs.push_back(tensors.size() - 1);
514     }
515     else
516     {
517         operatorInputs.push_back(kTfLiteOptionalTensor);
518     }
519     buffers.push_back(CreateBuffer(flatBufferBuilder));
520     tensors.push_back(CreateTensor(flatBufferBuilder,
521                                    flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
522                                                                            outputShape.size()),
523                                    ::tflite::TensorType_FLOAT32,
524                                    buffers.size() - 1,
525                                    flatBufferBuilder.CreateString("output")));
526     std::vector<int> operatorOutputs;
527     operatorOutputs.push_back(tensors.size() - 1);
528 
529     // create operator
530     tflite::BuiltinOptions    operatorBuiltinOptionsType = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
531     flatbuffers::Offset<void> operatorBuiltinOptions     =
532                                   CreateUnidirectionalSequenceLSTMOptions(flatBufferBuilder,
533                                                                           activationFunction,
534                                                                           clippingThresCell,
535                                                                           clippingThresProj,
536                                                                           isTimeMajor).Union();
537 
538     flatbuffers::Offset<Operator> lstmOperator =
539                                       CreateOperator(flatBufferBuilder,
540                                                      0,
541                                                      flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
542                                                                                              operatorInputs.size()),
543                                                      flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
544                                                                                              operatorOutputs.size()),
545                                                      operatorBuiltinOptionsType, operatorBuiltinOptions);
546 
547     flatbuffers::Offset<SubGraph> subgraph =
548                                       CreateSubGraph(flatBufferBuilder,
549                                                      flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
550                                                      flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
551                                                                                              operatorInputs.size()),
552                                                      flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
553                                                                                              operatorOutputs.size()),
554                                                      flatBufferBuilder.CreateVector(&lstmOperator, 1));
555 
556     flatbuffers::Offset<flatbuffers::String> modelDescription =
557                                                  flatBufferBuilder.CreateString(
558                                                      "ArmnnDelegate: UnidirectionalSequenceLSTM Operator Model");
559     flatbuffers::Offset<OperatorCode> operatorCode =
560                                                  CreateOperatorCode(flatBufferBuilder,
561                                                  tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM);
562 
563     flatbuffers::Offset<Model> flatbufferModel =
564                                    CreateModel(flatBufferBuilder,
565                                                TFLITE_SCHEMA_VERSION,
566                                                flatBufferBuilder.CreateVector(&operatorCode, 1),
567                                                flatBufferBuilder.CreateVector(&subgraph, 1),
568                                                modelDescription,
569                                                flatBufferBuilder.CreateVector(buffers));
570 
571     flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
572 
573     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
574                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
575 }
576 
577 template<typename T>
UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId> & backends,tflite::TensorType tensorType,int32_t batchSize,int32_t timeSize,int32_t inputSize,int32_t outputSize,int32_t numUnits,bool hasInputToInputWeights,const std::vector<T> & inputToInputWeights,const std::vector<T> & inputToForgetWeights,const std::vector<T> & inputToCellWeights,const std::vector<T> & inputToOutputWeights,bool hasRecurrentToInputWeights,const std::vector<T> & recurrentToInputWeights,const std::vector<T> & recurrentToForgetWeights,const std::vector<T> & recurrentToCellWeights,const std::vector<T> & recurrentToOutputWeights,bool hasCellToInputWeights,const std::vector<T> & cellToInputWeights,bool hasCellToForgetWeights,const std::vector<T> & cellToForgetWeights,bool hasCellToOutputWeights,const std::vector<T> & cellToOutputWeights,bool hasInputGateBias,const std::vector<float> & inputGateBias,const std::vector<float> & forgetGateBias,const std::vector<float> & cellBias,const std::vector<float> & outputGateBias,bool hasProjectionWeights,const std::vector<T> & projectionWeights,bool hasProjectionBias,const std::vector<float> & projectionBias,bool hasInputLayerNormWeights,const std::vector<float> & inputLayerNormWeights,bool hasForgetLayerNormWeights,const std::vector<float> & forgetLayerNormWeights,bool hasCellLayerNormWeights,const std::vector<float> & cellLayerNormWeights,bool hasOutputLayerNormWeights,const std::vector<float> & outputLayerNormWeights,std::vector<float> & inputValues,std::vector<float> & expectedOutputValues,tflite::ActivationFunctionType activationFunction,float clippingThresCell,float clippingThresProj,bool isTimeMajor,float quantScale=0.1f)578 void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
579                                         tflite::TensorType tensorType,
580                                         int32_t batchSize,
581                                         int32_t timeSize,
582                                         int32_t inputSize,
583                                         int32_t outputSize,
584                                         int32_t numUnits,
585                                         bool hasInputToInputWeights,
586                                         const std::vector<T>& inputToInputWeights,
587                                         const std::vector<T>& inputToForgetWeights,
588                                         const std::vector<T>& inputToCellWeights,
589                                         const std::vector<T>& inputToOutputWeights,
590                                         bool hasRecurrentToInputWeights,
591                                         const std::vector<T>& recurrentToInputWeights,
592                                         const std::vector<T>& recurrentToForgetWeights,
593                                         const std::vector<T>& recurrentToCellWeights,
594                                         const std::vector<T>& recurrentToOutputWeights,
595                                         bool hasCellToInputWeights,
596                                         const std::vector<T>& cellToInputWeights,
597                                         bool hasCellToForgetWeights,
598                                         const std::vector<T>& cellToForgetWeights,
599                                         bool hasCellToOutputWeights,
600                                         const std::vector<T>& cellToOutputWeights,
601                                         bool hasInputGateBias,
602                                         const std::vector<float>& inputGateBias,
603                                         const std::vector<float>& forgetGateBias,
604                                         const std::vector<float>& cellBias,
605                                         const std::vector<float>& outputGateBias,
606                                         bool hasProjectionWeights,
607                                         const std::vector<T>& projectionWeights,
608                                         bool hasProjectionBias,
609                                         const std::vector<float>& projectionBias,
610                                         bool hasInputLayerNormWeights,
611                                         const std::vector<float>& inputLayerNormWeights,
612                                         bool hasForgetLayerNormWeights,
613                                         const std::vector<float>& forgetLayerNormWeights,
614                                         bool hasCellLayerNormWeights,
615                                         const std::vector<float>& cellLayerNormWeights,
616                                         bool hasOutputLayerNormWeights,
617                                         const std::vector<float>& outputLayerNormWeights,
618                                         std::vector<float>& inputValues,
619                                         std::vector<float>& expectedOutputValues,
620                                         tflite::ActivationFunctionType activationFunction,
621                                         float clippingThresCell,
622                                         float clippingThresProj,
623                                         bool isTimeMajor,
624                                         float quantScale = 0.1f)
625 {
626     using namespace delegateTestInterpreter;
627 
628     std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
629                                                                                 batchSize,
630                                                                                 timeSize,
631                                                                                 inputSize,
632                                                                                 outputSize,
633                                                                                 numUnits,
634                                                                                 hasInputToInputWeights,
635                                                                                 inputToInputWeights,
636                                                                                 inputToForgetWeights,
637                                                                                 inputToCellWeights,
638                                                                                 inputToOutputWeights,
639                                                                                 hasRecurrentToInputWeights,
640                                                                                 recurrentToInputWeights,
641                                                                                 recurrentToForgetWeights,
642                                                                                 recurrentToCellWeights,
643                                                                                 recurrentToOutputWeights,
644                                                                                 hasCellToInputWeights,
645                                                                                 cellToInputWeights,
646                                                                                 hasCellToForgetWeights,
647                                                                                 cellToForgetWeights,
648                                                                                 hasCellToOutputWeights,
649                                                                                 cellToOutputWeights,
650                                                                                 hasInputGateBias,
651                                                                                 inputGateBias,
652                                                                                 forgetGateBias,
653                                                                                 cellBias,
654                                                                                 outputGateBias,
655                                                                                 hasProjectionWeights,
656                                                                                 projectionWeights,
657                                                                                 hasProjectionBias,
658                                                                                 projectionBias,
659                                                                                 hasInputLayerNormWeights,
660                                                                                 inputLayerNormWeights,
661                                                                                 hasForgetLayerNormWeights,
662                                                                                 forgetLayerNormWeights,
663                                                                                 hasCellLayerNormWeights,
664                                                                                 cellLayerNormWeights,
665                                                                                 hasOutputLayerNormWeights,
666                                                                                 outputLayerNormWeights,
667                                                                                 activationFunction,
668                                                                                 clippingThresCell,
669                                                                                 clippingThresProj,
670                                                                                 isTimeMajor,
671                                                                                 quantScale);
672 
673     std::vector<int32_t> outputShape;
674     if (isTimeMajor)
675     {
676         outputShape = {timeSize, batchSize, outputSize};
677     }
678     else
679     {
680         outputShape = {batchSize, timeSize, outputSize};
681     }
682 
683     // Setup interpreter with just TFLite Runtime.
684     auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
685     CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
686     CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
687     CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
688     std::vector<float>   tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
689     std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);
690 
691     // Setup interpreter with Arm NN Delegate applied.
692     auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
693     CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
694     CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
695     CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
696     std::vector<float>   armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
697     std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);
698 
699     armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
700 
701     if (tensorType == ::tflite::TensorType_INT8)
702     {
703         // Allow 2% tolerance for Quantized weights
704         armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
705                                    expectedOutputValues.size(), 2);
706         armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
707                                    expectedOutputValues.size(), 2);
708         armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
709                                    expectedOutputValues.size(), 2);
710     }
711     else
712     {
713         armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
714     }
715 
716     tfLiteInterpreter.Cleanup();
717     armnnInterpreter.Cleanup();
718 }
719 
720 } // anonymous namespace