//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <schema_generated.h>

#include <doctest/doctest.h>

namespace
{
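// Builds a FlatBuffer-serialized TFLite model containing a single SHAPE
// operator: one input tensor and one output tensor that holds the input's shape.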
std::vector<char> CreateShapeTfLiteModel(tflite::TensorType inputTensorType,
                                         tflite::TensorType outputTensorType,
                                         const std::vector<int32_t>& inputTensorShape,
                                         const std::vector<int32_t>& outputTensorShape,
                                         float quantScale = 1.0f,
                                         int quantOffset = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

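    // Buffer 0 is the schema's reserved empty buffer; buffers 1 and 2 back the
    // input and output tensors, which carry no constant data in this model.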
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));

    auto quantizationParameters =
             CreateQuantizationParameters(flatBufferBuilder,
                                          0,
                                          0,
                                          flatBufferBuilder.CreateVector<float>({ quantScale }),
                                          flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

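    // One input and one output tensor, referencing buffers 1 and 2 above.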
    std::array<flatbuffers::Offset<Tensor>, 2> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              inputTensorType,
                              1,
                              flatBufferBuilder.CreateString("input"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              outputTensorType,
                              2,
                              flatBufferBuilder.CreateString("output"),
                              quantizationParameters);

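    // A single SHAPE operator; ShapeOptions carries the requested output type.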
    const std::vector<int32_t> operatorInputs({ 0 });
    const std::vector<int32_t> operatorOutputs({ 1 });

    flatbuffers::Offset<Operator> shapeOperator =
                                      CreateOperator(flatBufferBuilder,
                                                     0,
                                                     flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
                                                                                             operatorInputs.size()),
                                                     flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
                                                                                             operatorOutputs.size()),
                                                     BuiltinOptions_ShapeOptions,
                                                     CreateShapeOptions(flatBufferBuilder, outputTensorType).Union());

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: SHAPE Operator Model");

    flatbuffers::Offset<OperatorCode> operatorCode =
        CreateOperatorCode(flatBufferBuilder, tflite::BuiltinOperator_SHAPE);

    const std::vector<int32_t>    subgraphInputs({ 0 });
    const std::vector<int32_t>    subgraphOutputs({ 1 });

    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(),
                                                               subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(),
                                                               subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&shapeOperator, 1));

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

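// Runs the single-operator SHAPE model once on the reference TFLite runtime
// and once with the Arm NN delegate applied, then checks both against the
// expected output values and shape. Note that inputValues is never written to
// the input tensor here: SHAPE's result depends only on the input's shape, so
// the values do not affect the outcome.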
template<typename T, typename K>
void ShapeTest(tflite::TensorType inputTensorType,
               tflite::TensorType outputTensorType,
               std::vector<armnn::BackendId>& backends,
               std::vector<int32_t>& inputShape,
               std::vector<T>& inputValues,
               std::vector<K>& expectedOutputValues,
               std::vector<int32_t>& expectedOutputShape,
               float quantScale = 1.0f,
               int quantOffset = 0)
{
    using namespace delegateTestInterpreter;
    std::vector<char> modelBuffer = CreateShapeTfLiteModel(inputTensorType,
                                                           outputTensorType,
                                                           inputShape,
                                                           expectedOutputShape,
                                                           quantScale,
                                                           quantOffset);

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<K>       tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<K>(0);
    std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<K>       armnnOutputValues = armnnInterpreter.GetOutputResult<K>(0);
    std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);

    armnnDelegate::CompareOutputData<K>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
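
// Example usage (illustrative sketch only; the backend choice and the concrete
// shapes/values below are assumptions, not taken from an existing test case):
//
//     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
//     std::vector<int32_t> inputShape { 1, 3, 2, 3 };
//     std::vector<float>   inputValues(1 * 3 * 2 * 3, 0.0f);
//     std::vector<int32_t> expectedOutputValues { 1, 3, 2, 3 };  // the input's shape
//     std::vector<int32_t> expectedOutputShape  { 4 };           // one entry per input dimension
//
//     ShapeTest<float, int32_t>(tflite::TensorType_FLOAT32,
//                               tflite::TensorType_INT32,
//                               backends,
//                               inputShape,
//                               inputValues,
//                               expectedOutputValues,
//                               expectedOutputShape);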

} // anonymous namespace