xref: /aosp_15_r20/external/armnn/delegate/common/src/test/DelegateTestInterpreter.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <DelegateTestInterpreterUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendId.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/core/c/c_api.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/kernel_util.h>
17*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/custom_ops_register.h>
18*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h>
19*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/c_api_internal.h>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace delegateTestInterpreter
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker class DelegateTestInterpreter
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker public:
27*89c4ff92SAndroid Build Coastguard Worker     /// Create TfLite Interpreter only
DelegateTestInterpreter(std::vector<char> & modelBuffer,const std::string & customOp="")28*89c4ff92SAndroid Build Coastguard Worker     DelegateTestInterpreter(std::vector<char>& modelBuffer, const std::string& customOp = "")
29*89c4ff92SAndroid Build Coastguard Worker     {
30*89c4ff92SAndroid Build Coastguard Worker         TfLiteModel* model = delegateTestInterpreter::CreateTfLiteModel(modelBuffer);
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker         TfLiteInterpreterOptions* options = delegateTestInterpreter::CreateTfLiteInterpreterOptions();
33*89c4ff92SAndroid Build Coastguard Worker         if (!customOp.empty())
34*89c4ff92SAndroid Build Coastguard Worker         {
35*89c4ff92SAndroid Build Coastguard Worker             options->mutable_op_resolver = delegateTestInterpreter::GenerateCustomOpResolver(customOp);
36*89c4ff92SAndroid Build Coastguard Worker         }
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker         m_TfLiteInterpreter = TfLiteInterpreterCreate(model, options);
39*89c4ff92SAndroid Build Coastguard Worker         m_TfLiteDelegate = nullptr;
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker         // The options and model can be deleted after the interpreter is created.
42*89c4ff92SAndroid Build Coastguard Worker         TfLiteInterpreterOptionsDelete(options);
43*89c4ff92SAndroid Build Coastguard Worker         TfLiteModelDelete(model);
44*89c4ff92SAndroid Build Coastguard Worker     }
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     /// Create Interpreter with default Arm NN Classic/Opaque Delegate applied
47*89c4ff92SAndroid Build Coastguard Worker     DelegateTestInterpreter(std::vector<char>& model,
48*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::BackendId>& backends,
49*89c4ff92SAndroid Build Coastguard Worker                             const std::string& customOp = "",
50*89c4ff92SAndroid Build Coastguard Worker                             bool disableFallback = true);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     /// Create Interpreter with Arm NN Classic/Opaque Delegate applied and DelegateOptions
53*89c4ff92SAndroid Build Coastguard Worker     DelegateTestInterpreter(std::vector<char>& model,
54*89c4ff92SAndroid Build Coastguard Worker                             const armnnDelegate::DelegateOptions& delegateOptions,
55*89c4ff92SAndroid Build Coastguard Worker                             const std::string& customOp = "");
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     /// Allocate the TfLiteTensors within the graph.
58*89c4ff92SAndroid Build Coastguard Worker     /// This must be called before FillInputTensor(values, index) and Invoke().
AllocateTensors()59*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus AllocateTensors()
60*89c4ff92SAndroid Build Coastguard Worker     {
61*89c4ff92SAndroid Build Coastguard Worker         return TfLiteInterpreterAllocateTensors(m_TfLiteInterpreter);
62*89c4ff92SAndroid Build Coastguard Worker     }
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     /// Copy a buffer of values into an input tensor at a given index.
65*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
FillInputTensor(std::vector<T> & inputValues,int index)66*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus FillInputTensor(std::vector<T>& inputValues, int index)
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         TfLiteTensor* inputTensor = delegateTestInterpreter::GetInputTensorFromInterpreter(m_TfLiteInterpreter, index);
69*89c4ff92SAndroid Build Coastguard Worker         return delegateTestInterpreter::CopyFromBufferToTensor(inputTensor, inputValues);
70*89c4ff92SAndroid Build Coastguard Worker     }
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     /// Copy a boolean buffer of values into an input tensor at a given index.
73*89c4ff92SAndroid Build Coastguard Worker     /// Boolean types get converted to a bit representation in a vector.
74*89c4ff92SAndroid Build Coastguard Worker     /// vector.data() returns a void pointer instead of a pointer to bool, so the tensor needs to be accessed directly.
FillInputTensor(std::vector<bool> & inputValues,int index)75*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus FillInputTensor(std::vector<bool>& inputValues, int index)
76*89c4ff92SAndroid Build Coastguard Worker     {
77*89c4ff92SAndroid Build Coastguard Worker         TfLiteTensor* inputTensor = delegateTestInterpreter::GetInputTensorFromInterpreter(m_TfLiteInterpreter, index);
78*89c4ff92SAndroid Build Coastguard Worker         if(inputTensor->type != kTfLiteBool)
79*89c4ff92SAndroid Build Coastguard Worker         {
80*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("Input tensor at the given index is not of bool type: " + std::to_string(index));
81*89c4ff92SAndroid Build Coastguard Worker         }
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker         // Make sure there is enough bytes allocated to copy into.
84*89c4ff92SAndroid Build Coastguard Worker         if(inputTensor->bytes < inputValues.size() * sizeof(bool))
85*89c4ff92SAndroid Build Coastguard Worker         {
86*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("Input tensor has not been allocated to match number of input values.");
87*89c4ff92SAndroid Build Coastguard Worker         }
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < inputValues.size(); ++i)
90*89c4ff92SAndroid Build Coastguard Worker         {
91*89c4ff92SAndroid Build Coastguard Worker             inputTensor->data.b[i] = inputValues[i];
92*89c4ff92SAndroid Build Coastguard Worker         }
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteOk;
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     /// Run the interpreter either on TFLite Runtime or Arm NN Delegate.
98*89c4ff92SAndroid Build Coastguard Worker     /// AllocateTensors() must be called before Invoke().
Invoke()99*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus Invoke()
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         return TfLiteInterpreterInvoke(m_TfLiteInterpreter);
102*89c4ff92SAndroid Build Coastguard Worker     }
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     /// Return a buffer of values from the output tensor at a given index.
105*89c4ff92SAndroid Build Coastguard Worker     /// This must be called after Invoke().
106*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
GetOutputResult(int index)107*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> GetOutputResult(int index)
108*89c4ff92SAndroid Build Coastguard Worker     {
109*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor* outputTensor =
110*89c4ff92SAndroid Build Coastguard Worker                 delegateTestInterpreter::GetOutputTensorFromInterpreter(m_TfLiteInterpreter, index);
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker         int64_t n = tflite::NumElements(outputTensor);
113*89c4ff92SAndroid Build Coastguard Worker         std::vector<T> output;
114*89c4ff92SAndroid Build Coastguard Worker         output.resize(n);
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker         TfLiteStatus status = TfLiteTensorCopyToBuffer(outputTensor, output.data(), output.size() * sizeof(T));
117*89c4ff92SAndroid Build Coastguard Worker         if(status != kTfLiteOk)
118*89c4ff92SAndroid Build Coastguard Worker         {
119*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("An error occurred when copying output buffer.");
120*89c4ff92SAndroid Build Coastguard Worker         }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker         return output;
123*89c4ff92SAndroid Build Coastguard Worker     }
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker     /// Return a buffer of values from the output tensor at a given index. This must be called after Invoke().
126*89c4ff92SAndroid Build Coastguard Worker     /// Boolean types get converted to a bit representation in a vector.
127*89c4ff92SAndroid Build Coastguard Worker     /// vector.data() returns a void pointer instead of a pointer to bool, so the tensor needs to be accessed directly.
GetOutputResult(int index)128*89c4ff92SAndroid Build Coastguard Worker     std::vector<bool> GetOutputResult(int index)
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor* outputTensor =
131*89c4ff92SAndroid Build Coastguard Worker                 delegateTestInterpreter::GetOutputTensorFromInterpreter(m_TfLiteInterpreter, index);
132*89c4ff92SAndroid Build Coastguard Worker         if(outputTensor->type != kTfLiteBool)
133*89c4ff92SAndroid Build Coastguard Worker         {
134*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("Output tensor at the given index is not of bool type: " + std::to_string(index));
135*89c4ff92SAndroid Build Coastguard Worker         }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker         int64_t n = tflite::NumElements(outputTensor);
138*89c4ff92SAndroid Build Coastguard Worker         std::vector<bool> output(n, false);
139*89c4ff92SAndroid Build Coastguard Worker         output.reserve(n);
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < output.size(); ++i)
142*89c4ff92SAndroid Build Coastguard Worker         {
143*89c4ff92SAndroid Build Coastguard Worker             output[i] = outputTensor->data.b[i];
144*89c4ff92SAndroid Build Coastguard Worker         }
145*89c4ff92SAndroid Build Coastguard Worker         return output;
146*89c4ff92SAndroid Build Coastguard Worker     }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     /// Return a buffer of dimensions from the output tensor at a given index.
GetOutputShape(int index)149*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> GetOutputShape(int index)
150*89c4ff92SAndroid Build Coastguard Worker     {
151*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor* outputTensor =
152*89c4ff92SAndroid Build Coastguard Worker                 delegateTestInterpreter::GetOutputTensorFromInterpreter(m_TfLiteInterpreter, index);
153*89c4ff92SAndroid Build Coastguard Worker         int32_t numDims = TfLiteTensorNumDims(outputTensor);
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> dims;
156*89c4ff92SAndroid Build Coastguard Worker         dims.reserve(numDims);
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker         for (int32_t i = 0; i < numDims; ++i)
159*89c4ff92SAndroid Build Coastguard Worker         {
160*89c4ff92SAndroid Build Coastguard Worker             dims.push_back(TfLiteTensorDim(outputTensor, i));
161*89c4ff92SAndroid Build Coastguard Worker         }
162*89c4ff92SAndroid Build Coastguard Worker         return dims;
163*89c4ff92SAndroid Build Coastguard Worker     }
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker     /// Delete TfLiteInterpreter and the TfLiteDelegate/TfLiteOpaqueDelegate
166*89c4ff92SAndroid Build Coastguard Worker     void Cleanup();
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker private:
169*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreter* m_TfLiteInterpreter;
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     /// m_TfLiteDelegate can be TfLiteDelegate or TfLiteOpaqueDelegate
172*89c4ff92SAndroid Build Coastguard Worker     void* m_TfLiteDelegate;
173*89c4ff92SAndroid Build Coastguard Worker };
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace