xref: /aosp_15_r20/external/armnn/delegate/test/ArgMinMaxTestHelper.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "TestUtils.hpp"
9 
10 #include <armnn_delegate.hpp>
11 #include <DelegateTestInterpreter.hpp>
12 
13 #include <flatbuffers/flatbuffers.h>
14 #include <tensorflow/lite/kernels/register.h>
15 #include <tensorflow/lite/version.h>
16 
17 #include <schema_generated.h>
18 
19 #include <doctest/doctest.h>
20 
21 namespace
22 {
23 
24 template <typename InputT, typename OutputT>
CreateArgMinMaxTfLiteModel(tflite::BuiltinOperator argMinMaxOperatorCode,tflite::TensorType tensorType,const std::vector<int32_t> & inputTensorShape,const std::vector<int32_t> & axisTensorShape,const std::vector<int32_t> & outputTensorShape,const std::vector<OutputT> axisValue,tflite::TensorType outputType,float quantScale=1.0f,int quantOffset=0)25 std::vector<char> CreateArgMinMaxTfLiteModel(tflite::BuiltinOperator argMinMaxOperatorCode,
26                                              tflite::TensorType tensorType,
27                                              const std::vector<int32_t>& inputTensorShape,
28                                              const std::vector<int32_t>& axisTensorShape,
29                                              const std::vector<int32_t>& outputTensorShape,
30                                              const std::vector<OutputT> axisValue,
31                                              tflite::TensorType outputType,
32                                              float quantScale = 1.0f,
33                                              int quantOffset  = 0)
34 {
35     using namespace tflite;
36     flatbuffers::FlatBufferBuilder flatBufferBuilder;
37 
38     auto quantizationParameters =
39         CreateQuantizationParameters(flatBufferBuilder,
40                                      0,
41                                      0,
42                                      flatBufferBuilder.CreateVector<float>({ quantScale }),
43                                      flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
44 
45     auto inputTensor = CreateTensor(flatBufferBuilder,
46                                     flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
47                                                                             inputTensorShape.size()),
48                                     tensorType,
49                                     1,
50                                     flatBufferBuilder.CreateString("input"),
51                                     quantizationParameters);
52 
53     auto axisTensor = CreateTensor(flatBufferBuilder,
54                                    flatBufferBuilder.CreateVector<int32_t>(axisTensorShape.data(),
55                                                                            axisTensorShape.size()),
56                                    tflite::TensorType_INT32,
57                                    2,
58                                    flatBufferBuilder.CreateString("axis"));
59 
60     auto outputTensor = CreateTensor(flatBufferBuilder,
61                                      flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
62                                                                              outputTensorShape.size()),
63                                      outputType,
64                                      3,
65                                      flatBufferBuilder.CreateString("output"),
66                                      quantizationParameters);
67 
68     std::vector<flatbuffers::Offset<Tensor>> tensors = { inputTensor, axisTensor, outputTensor };
69 
70     std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
71     buffers.push_back(CreateBuffer(flatBufferBuilder));
72     buffers.push_back(CreateBuffer(flatBufferBuilder));
73     buffers.push_back(
74         CreateBuffer(flatBufferBuilder,
75                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(axisValue.data()),
76                                                     sizeof(OutputT))));
77     buffers.push_back(CreateBuffer(flatBufferBuilder));
78 
79     std::vector<int32_t> operatorInputs = {{ 0, 1 }};
80     std::vector<int> subgraphInputs = {{ 0, 1 }};
81 
82     tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_ArgMaxOptions;
83     flatbuffers::Offset<void> operatorBuiltinOptions = CreateArgMaxOptions(flatBufferBuilder, outputType).Union();
84 
85     if (argMinMaxOperatorCode == tflite::BuiltinOperator_ARG_MIN)
86     {
87         operatorBuiltinOptionsType = BuiltinOptions_ArgMinOptions;
88         operatorBuiltinOptions = CreateArgMinOptions(flatBufferBuilder, outputType).Union();
89     }
90 
91     // create operator
92     const std::vector<int32_t> operatorOutputs{ 2 };
93     flatbuffers::Offset <Operator> argMinMaxOperator =
94         CreateOperator(flatBufferBuilder,
95                        0,
96                        flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
97                        flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
98                        operatorBuiltinOptionsType,
99                        operatorBuiltinOptions);
100 
101     const std::vector<int> subgraphOutputs{ 2 };
102     flatbuffers::Offset <SubGraph> subgraph =
103         CreateSubGraph(flatBufferBuilder,
104                        flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
105                        flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
106                        flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
107                        flatBufferBuilder.CreateVector(&argMinMaxOperator, 1));
108 
109     flatbuffers::Offset <flatbuffers::String> modelDescription =
110         flatBufferBuilder.CreateString("ArmnnDelegate: ArgMinMax Operator Model");
111     flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
112                                                                          argMinMaxOperatorCode);
113 
114     flatbuffers::Offset <Model> flatbufferModel =
115         CreateModel(flatBufferBuilder,
116                     TFLITE_SCHEMA_VERSION,
117                     flatBufferBuilder.CreateVector(&operatorCode, 1),
118                     flatBufferBuilder.CreateVector(&subgraph, 1),
119                     modelDescription,
120                     flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
121 
122     flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
123 
124     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
125                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
126 }
127 
128 template <typename InputT, typename OutputT>
ArgMinMaxTest(tflite::BuiltinOperator argMinMaxOperatorCode,tflite::TensorType tensorType,const std::vector<armnn::BackendId> & backends,const std::vector<int32_t> & inputShape,const std::vector<int32_t> & axisShape,std::vector<int32_t> & outputShape,std::vector<InputT> & inputValues,std::vector<OutputT> & expectedOutputValues,OutputT axisValue,tflite::TensorType outputType,float quantScale=1.0f,int quantOffset=0)129 void ArgMinMaxTest(tflite::BuiltinOperator argMinMaxOperatorCode,
130                    tflite::TensorType tensorType,
131                    const std::vector<armnn::BackendId>& backends,
132                    const std::vector<int32_t>& inputShape,
133                    const std::vector<int32_t>& axisShape,
134                    std::vector<int32_t>& outputShape,
135                    std::vector<InputT>& inputValues,
136                    std::vector<OutputT>& expectedOutputValues,
137                    OutputT axisValue,
138                    tflite::TensorType outputType,
139                    float quantScale = 1.0f,
140                    int quantOffset  = 0)
141 {
142     using namespace delegateTestInterpreter;
143     std::vector<char> modelBuffer = CreateArgMinMaxTfLiteModel<InputT, OutputT>(argMinMaxOperatorCode,
144                                                                                 tensorType,
145                                                                                 inputShape,
146                                                                                 axisShape,
147                                                                                 outputShape,
148                                                                                 {axisValue},
149                                                                                 outputType,
150                                                                                 quantScale,
151                                                                                 quantOffset);
152 
153     // Setup interpreter with just TFLite Runtime.
154     auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
155     CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
156     CHECK(tfLiteInterpreter.FillInputTensor<InputT>(inputValues, 0) == kTfLiteOk);
157     CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
158     std::vector<OutputT> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<OutputT>(0);
159     std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);
160 
161     // Setup interpreter with Arm NN Delegate applied.
162     auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
163     CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
164     CHECK(armnnInterpreter.FillInputTensor<InputT>(inputValues, 0) == kTfLiteOk);
165     CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
166     std::vector<OutputT> armnnOutputValues = armnnInterpreter.GetOutputResult<OutputT>(0);
167     std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);
168 
169     armnnDelegate::CompareOutputData<OutputT>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
170     armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
171 
172     tfLiteInterpreter.Cleanup();
173     armnnInterpreter.Cleanup();
174 }
175 
176 } // anonymous namespace