1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "TestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreter.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/version.h>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker namespace
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker
CreateNormalizationTfLiteModel(tflite::BuiltinOperator normalizationOperatorCode,tflite::TensorType tensorType,const std::vector<int32_t> & inputTensorShape,const std::vector<int32_t> & outputTensorShape,int32_t radius,float bias,float alpha,float beta,float quantScale=1.0f,int quantOffset=0)24*89c4ff92SAndroid Build Coastguard Worker std::vector<char> CreateNormalizationTfLiteModel(tflite::BuiltinOperator normalizationOperatorCode,
25*89c4ff92SAndroid Build Coastguard Worker tflite::TensorType tensorType,
26*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& inputTensorShape,
27*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& outputTensorShape,
28*89c4ff92SAndroid Build Coastguard Worker int32_t radius,
29*89c4ff92SAndroid Build Coastguard Worker float bias,
30*89c4ff92SAndroid Build Coastguard Worker float alpha,
31*89c4ff92SAndroid Build Coastguard Worker float beta,
32*89c4ff92SAndroid Build Coastguard Worker float quantScale = 1.0f,
33*89c4ff92SAndroid Build Coastguard Worker int quantOffset = 0)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker using namespace tflite;
36*89c4ff92SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder flatBufferBuilder;
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker auto quantizationParameters =
39*89c4ff92SAndroid Build Coastguard Worker CreateQuantizationParameters(flatBufferBuilder,
40*89c4ff92SAndroid Build Coastguard Worker 0,
41*89c4ff92SAndroid Build Coastguard Worker 0,
42*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<float>({ quantScale }),
43*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker auto inputTensor = CreateTensor(flatBufferBuilder,
46*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
47*89c4ff92SAndroid Build Coastguard Worker inputTensorShape.size()),
48*89c4ff92SAndroid Build Coastguard Worker tensorType,
49*89c4ff92SAndroid Build Coastguard Worker 1,
50*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateString("input"),
51*89c4ff92SAndroid Build Coastguard Worker quantizationParameters);
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker auto outputTensor = CreateTensor(flatBufferBuilder,
54*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
55*89c4ff92SAndroid Build Coastguard Worker outputTensorShape.size()),
56*89c4ff92SAndroid Build Coastguard Worker tensorType,
57*89c4ff92SAndroid Build Coastguard Worker 2,
58*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateString("output"),
59*89c4ff92SAndroid Build Coastguard Worker quantizationParameters);
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<Tensor>> tensors = { inputTensor, outputTensor };
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
64*89c4ff92SAndroid Build Coastguard Worker buffers.push_back(CreateBuffer(flatBufferBuilder));
65*89c4ff92SAndroid Build Coastguard Worker buffers.push_back(CreateBuffer(flatBufferBuilder));
66*89c4ff92SAndroid Build Coastguard Worker buffers.push_back(CreateBuffer(flatBufferBuilder));
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> operatorInputs = { 0 };
69*89c4ff92SAndroid Build Coastguard Worker std::vector<int> subgraphInputs = { 0 };
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_L2NormOptions;
72*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Offset<void> operatorBuiltinOptions = CreateL2NormOptions(flatBufferBuilder,
73*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType_NONE).Union();
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker if (normalizationOperatorCode == tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker operatorBuiltinOptionsType = BuiltinOptions_LocalResponseNormalizationOptions;
78*89c4ff92SAndroid Build Coastguard Worker operatorBuiltinOptions =
79*89c4ff92SAndroid Build Coastguard Worker CreateLocalResponseNormalizationOptions(flatBufferBuilder, radius, bias, alpha, beta).Union();
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker // create operator
83*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t> operatorOutputs{ 1 };
84*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Offset <Operator> normalizationOperator =
85*89c4ff92SAndroid Build Coastguard Worker CreateOperator(flatBufferBuilder,
86*89c4ff92SAndroid Build Coastguard Worker 0,
87*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
88*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
89*89c4ff92SAndroid Build Coastguard Worker operatorBuiltinOptionsType,
90*89c4ff92SAndroid Build Coastguard Worker operatorBuiltinOptions);
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker const std::vector<int> subgraphOutputs{ 1 };
93*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Offset <SubGraph> subgraph =
94*89c4ff92SAndroid Build Coastguard Worker CreateSubGraph(flatBufferBuilder,
95*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
96*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
97*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
98*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector(&normalizationOperator, 1));
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Offset <flatbuffers::String> modelDescription =
101*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateString("ArmnnDelegate: Normalization Operator Model");
102*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
103*89c4ff92SAndroid Build Coastguard Worker normalizationOperatorCode);
104*89c4ff92SAndroid Build Coastguard Worker
105*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Offset <Model> flatbufferModel =
106*89c4ff92SAndroid Build Coastguard Worker CreateModel(flatBufferBuilder,
107*89c4ff92SAndroid Build Coastguard Worker TFLITE_SCHEMA_VERSION,
108*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector(&operatorCode, 1),
109*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector(&subgraph, 1),
110*89c4ff92SAndroid Build Coastguard Worker modelDescription,
111*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
112*89c4ff92SAndroid Build Coastguard Worker
113*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
114*89c4ff92SAndroid Build Coastguard Worker
115*89c4ff92SAndroid Build Coastguard Worker return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
116*89c4ff92SAndroid Build Coastguard Worker flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker template <typename T>
NormalizationTest(tflite::BuiltinOperator normalizationOperatorCode,tflite::TensorType tensorType,const std::vector<armnn::BackendId> & backends,const std::vector<int32_t> & inputShape,std::vector<int32_t> & outputShape,std::vector<T> & inputValues,std::vector<T> & expectedOutputValues,int32_t radius=0,float bias=0.f,float alpha=0.f,float beta=0.f,float quantScale=1.0f,int quantOffset=0)120*89c4ff92SAndroid Build Coastguard Worker void NormalizationTest(tflite::BuiltinOperator normalizationOperatorCode,
121*89c4ff92SAndroid Build Coastguard Worker tflite::TensorType tensorType,
122*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BackendId>& backends,
123*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& inputShape,
124*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t>& outputShape,
125*89c4ff92SAndroid Build Coastguard Worker std::vector<T>& inputValues,
126*89c4ff92SAndroid Build Coastguard Worker std::vector<T>& expectedOutputValues,
127*89c4ff92SAndroid Build Coastguard Worker int32_t radius = 0,
128*89c4ff92SAndroid Build Coastguard Worker float bias = 0.f,
129*89c4ff92SAndroid Build Coastguard Worker float alpha = 0.f,
130*89c4ff92SAndroid Build Coastguard Worker float beta = 0.f,
131*89c4ff92SAndroid Build Coastguard Worker float quantScale = 1.0f,
132*89c4ff92SAndroid Build Coastguard Worker int quantOffset = 0)
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker using namespace delegateTestInterpreter;
135*89c4ff92SAndroid Build Coastguard Worker std::vector<char> modelBuffer = CreateNormalizationTfLiteModel(normalizationOperatorCode,
136*89c4ff92SAndroid Build Coastguard Worker tensorType,
137*89c4ff92SAndroid Build Coastguard Worker inputShape,
138*89c4ff92SAndroid Build Coastguard Worker outputShape,
139*89c4ff92SAndroid Build Coastguard Worker radius,
140*89c4ff92SAndroid Build Coastguard Worker bias,
141*89c4ff92SAndroid Build Coastguard Worker alpha,
142*89c4ff92SAndroid Build Coastguard Worker beta,
143*89c4ff92SAndroid Build Coastguard Worker quantScale,
144*89c4ff92SAndroid Build Coastguard Worker quantOffset);
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker // Setup interpreter with just TFLite Runtime.
147*89c4ff92SAndroid Build Coastguard Worker auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
148*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
149*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
150*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
151*89c4ff92SAndroid Build Coastguard Worker std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
152*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker // Setup interpreter with Arm NN Delegate applied.
155*89c4ff92SAndroid Build Coastguard Worker auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
156*89c4ff92SAndroid Build Coastguard Worker CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
157*89c4ff92SAndroid Build Coastguard Worker CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
158*89c4ff92SAndroid Build Coastguard Worker CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
159*89c4ff92SAndroid Build Coastguard Worker std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
160*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
163*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker tfLiteInterpreter.Cleanup();
166*89c4ff92SAndroid Build Coastguard Worker armnnInterpreter.Cleanup();
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker
L2NormalizationTest(std::vector<armnn::BackendId> & backends)169*89c4ff92SAndroid Build Coastguard Worker void L2NormalizationTest(std::vector<armnn::BackendId>& backends)
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker // Set input data
172*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 1, 1, 1, 10 };
173*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1, 1, 1, 10 };
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues
176*89c4ff92SAndroid Build Coastguard Worker {
177*89c4ff92SAndroid Build Coastguard Worker 1.0f,
178*89c4ff92SAndroid Build Coastguard Worker 2.0f,
179*89c4ff92SAndroid Build Coastguard Worker 3.0f,
180*89c4ff92SAndroid Build Coastguard Worker 4.0f,
181*89c4ff92SAndroid Build Coastguard Worker 5.0f,
182*89c4ff92SAndroid Build Coastguard Worker 6.0f,
183*89c4ff92SAndroid Build Coastguard Worker 7.0f,
184*89c4ff92SAndroid Build Coastguard Worker 8.0f,
185*89c4ff92SAndroid Build Coastguard Worker 9.0f,
186*89c4ff92SAndroid Build Coastguard Worker 10.0f
187*89c4ff92SAndroid Build Coastguard Worker };
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker const float approxInvL2Norm = 0.050964719f;
190*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker 1.0f * approxInvL2Norm,
193*89c4ff92SAndroid Build Coastguard Worker 2.0f * approxInvL2Norm,
194*89c4ff92SAndroid Build Coastguard Worker 3.0f * approxInvL2Norm,
195*89c4ff92SAndroid Build Coastguard Worker 4.0f * approxInvL2Norm,
196*89c4ff92SAndroid Build Coastguard Worker 5.0f * approxInvL2Norm,
197*89c4ff92SAndroid Build Coastguard Worker 6.0f * approxInvL2Norm,
198*89c4ff92SAndroid Build Coastguard Worker 7.0f * approxInvL2Norm,
199*89c4ff92SAndroid Build Coastguard Worker 8.0f * approxInvL2Norm,
200*89c4ff92SAndroid Build Coastguard Worker 9.0f * approxInvL2Norm,
201*89c4ff92SAndroid Build Coastguard Worker 10.0f * approxInvL2Norm
202*89c4ff92SAndroid Build Coastguard Worker };
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker NormalizationTest<float>(tflite::BuiltinOperator_L2_NORMALIZATION,
205*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
206*89c4ff92SAndroid Build Coastguard Worker backends,
207*89c4ff92SAndroid Build Coastguard Worker inputShape,
208*89c4ff92SAndroid Build Coastguard Worker outputShape,
209*89c4ff92SAndroid Build Coastguard Worker inputValues,
210*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues);
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker
LocalResponseNormalizationTest(std::vector<armnn::BackendId> & backends,int32_t radius,float bias,float alpha,float beta)213*89c4ff92SAndroid Build Coastguard Worker void LocalResponseNormalizationTest(std::vector<armnn::BackendId>& backends,
214*89c4ff92SAndroid Build Coastguard Worker int32_t radius,
215*89c4ff92SAndroid Build Coastguard Worker float bias,
216*89c4ff92SAndroid Build Coastguard Worker float alpha,
217*89c4ff92SAndroid Build Coastguard Worker float beta)
218*89c4ff92SAndroid Build Coastguard Worker {
219*89c4ff92SAndroid Build Coastguard Worker // Set input data
220*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 2, 2, 2, 1 };
221*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2, 2, 2, 1 };
222*89c4ff92SAndroid Build Coastguard Worker
223*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues
224*89c4ff92SAndroid Build Coastguard Worker {
225*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f,
226*89c4ff92SAndroid Build Coastguard Worker 3.0f, 4.0f,
227*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f,
228*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f
229*89c4ff92SAndroid Build Coastguard Worker };
230*89c4ff92SAndroid Build Coastguard Worker
231*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues
232*89c4ff92SAndroid Build Coastguard Worker {
233*89c4ff92SAndroid Build Coastguard Worker 0.5f, 0.400000006f, 0.300000012f, 0.235294119f,
234*89c4ff92SAndroid Build Coastguard Worker 0.192307696f, 0.16216217f, 0.140000001f, 0.123076923f
235*89c4ff92SAndroid Build Coastguard Worker };
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker NormalizationTest<float>(tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
238*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
239*89c4ff92SAndroid Build Coastguard Worker backends,
240*89c4ff92SAndroid Build Coastguard Worker inputShape,
241*89c4ff92SAndroid Build Coastguard Worker outputShape,
242*89c4ff92SAndroid Build Coastguard Worker inputValues,
243*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
244*89c4ff92SAndroid Build Coastguard Worker radius,
245*89c4ff92SAndroid Build Coastguard Worker bias,
246*89c4ff92SAndroid Build Coastguard Worker alpha,
247*89c4ff92SAndroid Build Coastguard Worker beta);
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace